diff --git a/mlir/include/mlir/Dialect/Vector/VectorOps.td b/mlir/include/mlir/Dialect/Vector/VectorOps.td --- a/mlir/include/mlir/Dialect/Vector/VectorOps.td +++ b/mlir/include/mlir/Dialect/Vector/VectorOps.td @@ -1574,11 +1574,14 @@ closely correspond to those of the `llvm.masked.load` [intrinsic](https://llvm.org/docs/LangRef.html#llvm-masked-load-intrinsics). - Example: + Examples: ```mlir %0 = vector.maskedload %base[%i], %mask, %pass_thru : memref, vector<8xi1>, vector<8xf32> into vector<8xf32> + + %1 = vector.maskedload %base[%i, %j], %mask, %pass_thru + : memref, vector<16xi1>, vector<16xf32> into vector<16xf32> ``` }]; let extraClassDeclaration = [{ @@ -1625,11 +1628,14 @@ closely correspond to those of the `llvm.masked.store` [intrinsic](https://llvm.org/docs/LangRef.html#llvm-masked-store-intrinsics). - Example: + Examples: ```mlir vector.maskedstore %base[%i], %mask, %value : memref, vector<8xi1>, vector<8xf32> + + vector.maskedstore %base[%i, %j], %mask, %value + : memref, vector<16xi1>, vector<16xf32> ``` }]; let extraClassDeclaration = [{ @@ -1652,7 +1658,8 @@ def Vector_GatherOp : Vector_Op<"gather">, Arguments<(ins Arg:$base, - VectorOfRankAndType<[1], [AnyInteger]>:$indices, + Variadic:$indices, + VectorOfRankAndType<[1], [AnyInteger]>:$index, VectorOfRankAndType<[1], [I1]>:$mask, VectorOfRank<[1]>:$pass_thru)>, Results<(outs VectorOfRank<[1]>:$result)> { @@ -1661,9 +1668,10 @@ let description = [{ The gather operation gathers elements from memory into a 1-D vector as - defined by a base and a 1-D index vector, but only if the corresponding - bit is set in a 1-D mask vector. Otherwise, the element is taken from a - 1-D pass-through vector. Informally the semantics are: + defined by a base with indices and an additional 1-D index vector, but + only if the corresponding bit is set in a 1-D mask vector. Otherwise, the + element is taken from a 1-D pass-through vector. Informally the semantics + are: ``` result[0] := mask[0] ? 
base[index[0]] : pass_thru[0] result[1] := mask[1] ? base[index[1]] : pass_thru[1] @@ -1677,19 +1685,22 @@ correspond to those of the `llvm.masked.gather` [intrinsic](https://llvm.org/docs/LangRef.html#llvm-masked-gather-intrinsics). - Example: + Examples: ```mlir - %g = vector.gather %base[%indices], %mask, %pass_thru + %0 = vector.gather %base[][%v], %mask, %pass_thru : memref, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32> + + %1 = vector.gather %base[%i][%v], %mask, %pass_thru + : memref<16x16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32> ``` }]; let extraClassDeclaration = [{ MemRefType getMemRefType() { return base().getType().cast(); } - VectorType getIndicesVectorType() { - return indices().getType().cast(); + VectorType getIndexVectorType() { + return index().getType().cast(); } VectorType getMaskVectorType() { return mask().getType().cast(); @@ -1701,15 +1712,16 @@ return result().getType().cast(); } }]; - let assemblyFormat = "$base `[` $indices `]` `,` $mask `,` $pass_thru attr-dict `:` " - "type($base) `,` type($indices) `,` type($mask) `,` type($pass_thru) `into` type($result)"; + let assemblyFormat = "$base `[` $indices `]` `[` $index `]` `,` $mask `,` $pass_thru attr-dict `:` " + "type($base) `,` type($index) `,` type($mask) `,` type($pass_thru) `into` type($result)"; let hasCanonicalizer = 1; } def Vector_ScatterOp : Vector_Op<"scatter">, Arguments<(ins Arg:$base, - VectorOfRankAndType<[1], [AnyInteger]>:$indices, + Variadic:$indices, + VectorOfRankAndType<[1], [AnyInteger]>:$index, VectorOfRankAndType<[1], [I1]>:$mask, VectorOfRank<[1]>:$valueToStore)> { @@ -1717,9 +1729,9 @@ let description = [{ The scatter operation scatters elements from a 1-D vector into memory as - defined by a base and a 1-D index vector, but only if the corresponding - bit in a 1-D mask vector is set. Otherwise, no action is taken for that - element. 
Informally the semantics are: + defined by a base with indices and an additional 1-D index vector, but + only if the corresponding bit in a 1-D mask vector is set. Otherwise, no + action is taken for that element. Informally the semantics are: ``` if (mask[0]) base[index[0]] = value[0] if (mask[1]) base[index[1]] = value[1] @@ -1736,19 +1748,22 @@ correspond to those of the `llvm.masked.scatter` [intrinsic](https://llvm.org/docs/LangRef.html#llvm-masked-scatter-intrinsics). - Example: + Examples: ```mlir - vector.scatter %base[%indices], %mask, %value + vector.scatter %base[][%v], %mask, %value : memref, vector<16xi32>, vector<16xi1>, vector<16xf32> + + vector.scatter %base[%i][%v], %mask, %value + : memref<16x16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> ``` }]; let extraClassDeclaration = [{ MemRefType getMemRefType() { return base().getType().cast(); } - VectorType getIndicesVectorType() { - return indices().getType().cast(); + VectorType getIndexVectorType() { + return index().getType().cast(); } VectorType getMaskVectorType() { return mask().getType().cast(); @@ -1758,8 +1773,8 @@ } }]; let assemblyFormat = - "$base `[` $indices `]` `,` $mask `,` $valueToStore attr-dict `:` " - "type($base) `,` type($indices) `,` type($mask) `,` type($valueToStore)"; + "$base `[` $indices `]` `[` $index `]` `,` $mask `,` $valueToStore attr-dict `:` " + "type($base) `,` type($index) `,` type($mask) `,` type($valueToStore)"; let hasCanonicalizer = 1; } @@ -1792,11 +1807,14 @@ correspond to those of the `llvm.masked.expandload` [intrinsic](https://llvm.org/docs/LangRef.html#llvm-masked-expandload-intrinsics). 
- Example: + Examples: ```mlir %0 = vector.expandload %base[%i], %mask, %pass_thru : memref, vector<8xi1>, vector<8xf32> into vector<8xf32> + + %1 = vector.expandload %base[%i, %j], %mask, %pass_thru + : memref, vector<16xi1>, vector<16xf32> into vector<16xf32> ``` }]; let extraClassDeclaration = [{ @@ -1846,11 +1864,14 @@ correspond to those of the `llvm.masked.compressstore` [intrinsic](https://llvm.org/docs/LangRef.html#llvm-masked-compressstore-intrinsics). - Example: + Examples: ```mlir vector.compressstore %base[%i], %mask, %value : memref, vector<8xi1>, vector<8xf32> + + vector.compressstore %base[%i, %j], %mask, %value + : memref, vector<16xi1>, vector<16xf32> ``` }]; let extraClassDeclaration = [{ diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp @@ -178,34 +178,21 @@ return success(); } -// Helper that returns the base address of a memref. -static LogicalResult getBase(ConversionPatternRewriter &rewriter, Location loc, - Value memref, MemRefType memRefType, Value &base) { - // Inspect stride and offset structure. - // - // TODO: flat memory only for now, generalize - // +// Add an index vector component to a base pointer. This almost always succeeds +// unless the last stride is non-unit or the memory space is not zero. 
+static LogicalResult getIndexedPtrs(ConversionPatternRewriter &rewriter, + Location loc, Value memref, Value base, + Value index, MemRefType memRefType, + VectorType vType, Value &ptrs) { int64_t offset; SmallVector strides; auto successStrides = getStridesAndOffset(memRefType, strides, offset); - if (failed(successStrides) || strides.size() != 1 || strides[0] != 1 || - offset != 0 || memRefType.getMemorySpace() != 0) - return failure(); - base = MemRefDescriptor(memref).alignedPtr(rewriter, loc); - return success(); -} - -// Helper that returns vector of pointers given a memref base with index vector. -static LogicalResult getIndexedPtrs(ConversionPatternRewriter &rewriter, - Location loc, Value memref, Value indices, - MemRefType memRefType, VectorType vType, - Type iType, Value &ptrs) { - Value base; - if (failed(getBase(rewriter, loc, memref, memRefType, base))) + if (failed(successStrides) || strides.back() != 1 || + memRefType.getMemorySpace() != 0) return failure(); auto pType = MemRefDescriptor(memref).getElementPtrType(); auto ptrsType = LLVM::getFixedVectorType(pType, vType.getDimSize(0)); - ptrs = rewriter.create(loc, ptrsType, base, indices); + ptrs = rewriter.create(loc, ptrsType, base, index); return success(); } @@ -435,19 +422,20 @@ ConversionPatternRewriter &rewriter) const override { auto loc = gather->getLoc(); auto adaptor = vector::GatherOpAdaptor(operands); + MemRefType memRefType = gather.getMemRefType(); // Resolve alignment. unsigned align; - if (failed(getMemRefAlignment(*getTypeConverter(), gather.getMemRefType(), - align))) + if (failed(getMemRefAlignment(*getTypeConverter(), memRefType, align))) return failure(); - // Get index ptrs. - VectorType vType = gather.getVectorType(); - Type iType = gather.getIndicesVectorType().getElementType(); + // Resolve address. 
Value ptrs; - if (failed(getIndexedPtrs(rewriter, loc, adaptor.base(), adaptor.indices(), - gather.getMemRefType(), vType, iType, ptrs))) + VectorType vType = gather.getVectorType(); + Value ptr = getStridedElementPtr(loc, memRefType, adaptor.base(), + adaptor.indices(), rewriter); + if (failed(getIndexedPtrs(rewriter, loc, adaptor.base(), ptr, + adaptor.index(), memRefType, vType, ptrs))) return failure(); // Replace with the gather intrinsic. @@ -469,19 +457,20 @@ ConversionPatternRewriter &rewriter) const override { auto loc = scatter->getLoc(); auto adaptor = vector::ScatterOpAdaptor(operands); + MemRefType memRefType = scatter.getMemRefType(); // Resolve alignment. unsigned align; - if (failed(getMemRefAlignment(*getTypeConverter(), scatter.getMemRefType(), - align))) + if (failed(getMemRefAlignment(*getTypeConverter(), memRefType, align))) return failure(); - // Get index ptrs. - VectorType vType = scatter.getVectorType(); - Type iType = scatter.getIndicesVectorType().getElementType(); + // Resolve address. Value ptrs; - if (failed(getIndexedPtrs(rewriter, loc, adaptor.base(), adaptor.indices(), - scatter.getMemRefType(), vType, iType, ptrs))) + VectorType vType = scatter.getVectorType(); + Value ptr = getStridedElementPtr(loc, memRefType, adaptor.base(), + adaptor.indices(), rewriter); + if (failed(getIndexedPtrs(rewriter, loc, adaptor.base(), ptr, + adaptor.index(), memRefType, vType, ptrs))) return failure(); // Replace with the scatter intrinsic. @@ -507,8 +496,8 @@ // Resolve address. auto vtype = typeConverter->convertType(expand.getVectorType()); - Value ptr = this->getStridedElementPtr(loc, memRefType, adaptor.base(), - adaptor.indices(), rewriter); + Value ptr = getStridedElementPtr(loc, memRefType, adaptor.base(), + adaptor.indices(), rewriter); rewriter.replaceOpWithNewOp( expand, vtype, ptr, adaptor.mask(), adaptor.pass_thru()); @@ -530,8 +519,8 @@ MemRefType memRefType = compress.getMemRefType(); // Resolve address. 
- Value ptr = this->getStridedElementPtr(loc, memRefType, adaptor.base(), - adaptor.indices(), rewriter); + Value ptr = getStridedElementPtr(loc, memRefType, adaptor.base(), + adaptor.indices(), rewriter); rewriter.replaceOpWithNewOp( compress, adaptor.valueToStore(), ptr, adaptor.mask()); diff --git a/mlir/lib/Dialect/Linalg/Transforms/Sparsification.cpp b/mlir/lib/Dialect/Linalg/Transforms/Sparsification.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Sparsification.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Sparsification.cpp @@ -653,8 +653,8 @@ VectorType vtp = vectorType(codegen, ptr); Value pass = rewriter.create(loc, vtp, rewriter.getZeroAttr(vtp)); if (args.back().getType().isa()) - return rewriter.create(loc, vtp, ptr, args.back(), - codegen.curVecMask, pass); + return rewriter.create( + loc, vtp, ptr, args.drop_back(), args.back(), codegen.curVecMask, pass); return rewriter.create(loc, vtp, ptr, args, codegen.curVecMask, pass); } @@ -664,7 +664,7 @@ Value rhs, Value ptr, ArrayRef args) { Location loc = ptr.getLoc(); if (args.back().getType().isa()) - rewriter.create(loc, ptr, args.back(), + rewriter.create(loc, ptr, args.drop_back(), args.back(), codegen.curVecMask, rhs); else rewriter.create(loc, ptr, args, codegen.curVecMask, @@ -985,11 +985,13 @@ unsigned tensor = merger.tensor(b); assert(idx == merger.index(b)); types.push_back(indexType); + assert(codegen.pidxs[tensor][idx].getType().isa()); operands.push_back(codegen.pidxs[tensor][idx]); } } if (needsUniv) { types.push_back(indexType); + assert(codegen.loops[idx].getType().isa()); operands.push_back(codegen.loops[idx]); } Location loc = op.getLoc(); @@ -1160,6 +1162,7 @@ genTensorStore(merger, codegen, rewriter, op, lhs, rhs); return; } + assert(codegen.curVecLength == 1); // Construct iteration lattices for current loop index, with L0 at top. // Then emit initialization code for the loop sequence at this level. 
@@ -1239,6 +1242,7 @@ } genInvariants(merger, codegen, rewriter, op, exp, ldx, /*hoist=*/false); codegen.loops[idx] = Value(); + codegen.curVecLength = 1; } namespace { diff --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp --- a/mlir/lib/Dialect/Vector/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/VectorOps.cpp @@ -2771,14 +2771,16 @@ //===----------------------------------------------------------------------===// static LogicalResult verify(GatherOp op) { - VectorType indicesVType = op.getIndicesVectorType(); + VectorType indVType = op.getIndexVectorType(); VectorType maskVType = op.getMaskVectorType(); VectorType resVType = op.getVectorType(); MemRefType memType = op.getMemRefType(); if (resVType.getElementType() != memType.getElementType()) return op.emitOpError("base and result element type should match"); - if (resVType.getDimSize(0) != indicesVType.getDimSize(0)) + if (llvm::size(op.indices()) + 1 != memType.getRank()) + return op.emitOpError("requires ") << memType.getRank() << " indices"; + if (resVType.getDimSize(0) != indVType.getDimSize(0)) return op.emitOpError("expected result dim to match indices dim"); if (resVType.getDimSize(0) != maskVType.getDimSize(0)) return op.emitOpError("expected result dim to match mask dim"); @@ -2817,14 +2819,16 @@ //===----------------------------------------------------------------------===// static LogicalResult verify(ScatterOp op) { - VectorType indicesVType = op.getIndicesVectorType(); + VectorType indVType = op.getIndexVectorType(); VectorType maskVType = op.getMaskVectorType(); VectorType valueVType = op.getVectorType(); MemRefType memType = op.getMemRefType(); if (valueVType.getElementType() != memType.getElementType()) return op.emitOpError("base and valueToStore element type should match"); - if (valueVType.getDimSize(0) != indicesVType.getDimSize(0)) + if (llvm::size(op.indices()) + 1 != memType.getRank()) + return op.emitOpError("requires ") << memType.getRank() << " 
indices"; + if (valueVType.getDimSize(0) != indVType.getDimSize(0)) return op.emitOpError("expected valueToStore dim to match indices dim"); if (valueVType.getDimSize(0) != maskVType.getDimSize(0)) return op.emitOpError("expected valueToStore dim to match mask dim"); diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir --- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir +++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir @@ -1302,28 +1302,55 @@ // ----- func @gather_op(%arg0: memref, %arg1: vector<3xi32>, %arg2: vector<3xi1>, %arg3: vector<3xf32>) -> vector<3xf32> { - %0 = vector.gather %arg0[%arg1], %arg2, %arg3 : memref, vector<3xi32>, vector<3xi1>, vector<3xf32> into vector<3xf32> + %0 = vector.gather %arg0[][%arg1], %arg2, %arg3 : memref, vector<3xi32>, vector<3xi1>, vector<3xf32> into vector<3xf32> return %0 : vector<3xf32> } // CHECK-LABEL: func @gather_op -// CHECK: %[[P:.*]] = llvm.getelementptr {{.*}}[%{{.*}}] : (!llvm.ptr, vector<3xi32>) -> !llvm.vec<3 x ptr> +// CHECK: %[[P:.*]] = llvm.getelementptr %{{.*}}[%{{.*}}] : (!llvm.ptr, vector<3xi32>) -> !llvm.vec<3 x ptr> // CHECK: %[[G:.*]] = llvm.intr.masked.gather %[[P]], %{{.*}}, %{{.*}} {alignment = 4 : i32} : (!llvm.vec<3 x ptr>, vector<3xi1>, vector<3xf32>) -> vector<3xf32> // CHECK: return %[[G]] : vector<3xf32> // ----- +func @gather_2d_op(%arg0: memref<4x4xf32>, %arg1: vector<4xi32>, %arg2: vector<4xi1>, %arg3: vector<4xf32>) -> vector<4xf32> { + %0 = constant 3 : index + %1 = vector.gather %arg0[%0][%arg1], %arg2, %arg3 : memref<4x4xf32>, vector<4xi32>, vector<4xi1>, vector<4xf32> into vector<4xf32> + return %1 : vector<4xf32> +} + +// CHECK-LABEL: func @gather_2d_op +// CHECK: %[[B:.*]] = llvm.getelementptr %{{.*}}[%{{.*}}] : (!llvm.ptr, i64) -> !llvm.ptr +// CHECK: %[[P:.*]] = llvm.getelementptr %[[B]][%{{.*}}] : (!llvm.ptr, vector<4xi32>) -> !llvm.vec<4 x ptr> +// CHECK: %[[G:.*]] = llvm.intr.masked.gather 
%[[P]], %{{.*}}, %{{.*}} {alignment = 4 : i32} : (!llvm.vec<4 x ptr>, vector<4xi1>, vector<4xf32>) -> vector<4xf32> +// CHECK: return %[[G]] : vector<4xf32> + +// ----- + func @scatter_op(%arg0: memref, %arg1: vector<3xi32>, %arg2: vector<3xi1>, %arg3: vector<3xf32>) { - vector.scatter %arg0[%arg1], %arg2, %arg3 : memref, vector<3xi32>, vector<3xi1>, vector<3xf32> + vector.scatter %arg0[][%arg1], %arg2, %arg3 : memref, vector<3xi32>, vector<3xi1>, vector<3xf32> return } // CHECK-LABEL: func @scatter_op -// CHECK: %[[P:.*]] = llvm.getelementptr {{.*}}[%{{.*}}] : (!llvm.ptr, vector<3xi32>) -> !llvm.vec<3 x ptr> +// CHECK: %[[P:.*]] = llvm.getelementptr %{{.*}}[%{{.*}}] : (!llvm.ptr, vector<3xi32>) -> !llvm.vec<3 x ptr> // CHECK: llvm.intr.masked.scatter %{{.*}}, %[[P]], %{{.*}} {alignment = 4 : i32} : vector<3xf32>, vector<3xi1> into !llvm.vec<3 x ptr> // ----- +func @scatter_2d_op(%arg0: memref<4x4xf32>, %arg1: vector<4xi32>, %arg2: vector<4xi1>, %arg3: vector<4xf32>) { + %0 = constant 3 : index + vector.scatter %arg0[%0][%arg1], %arg2, %arg3 : memref<4x4xf32>, vector<4xi32>, vector<4xi1>, vector<4xf32> + return +} + +// CHECK-LABEL: func @scatter_2d_op +// CHECK: %[[B:.*]] = llvm.getelementptr %{{.*}}[%{{.*}}] : (!llvm.ptr, i64) -> !llvm.ptr +// CHECK: %[[P:.*]] = llvm.getelementptr %[[B]][%{{.*}}] : (!llvm.ptr, vector<4xi32>) -> !llvm.vec<4 x ptr> +// CHECK: llvm.intr.masked.scatter %{{.*}}, %[[P]], %{{.*}} {alignment = 4 : i32} : vector<4xf32>, vector<4xi1> into !llvm.vec<4 x ptr> + +// ----- + func @expand_load_op(%arg0: memref, %arg1: vector<11xi1>, %arg2: vector<11xf32>) -> vector<11xf32> { %c0 = constant 0: index %0 = vector.expandload %arg0[%c0], %arg1, %arg2 : memref, vector<11xi1>, vector<11xf32> into vector<11xf32> diff --git a/mlir/test/Dialect/Linalg/sparse_vector.mlir b/mlir/test/Dialect/Linalg/sparse_vector.mlir --- a/mlir/test/Dialect/Linalg/sparse_vector.mlir +++ b/mlir/test/Dialect/Linalg/sparse_vector.mlir @@ -128,9 +128,9 @@ // CHECK-VEC2: 
%[[mask:.*]] = vector.create_mask %[[sub]] : vector<16xi1> // CHECK-VEC2: %[[li:.*]] = vector.maskedload %{{.*}}[%[[i]]], %[[mask]], %{{.*}} : memref, vector<16xi1>, vector<16xi32> into vector<16xi32> // CHECK-VEC2: %[[la:.*]] = vector.maskedload %{{.*}}[%[[i]]], %[[mask]], %{{.*}} : memref, vector<16xi1>, vector<16xf32> into vector<16xf32> -// CHECK-VEC2: %[[lb:.*]] = vector.gather %{{.*}}[%[[li]]], %[[mask]], %{{.*}} : memref<1024xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32> +// CHECK-VEC2: %[[lb:.*]] = vector.gather %{{.*}}[] [%[[li]]], %[[mask]], %{{.*}} : memref<1024xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32> // CHECK-VEC2: %[[m:.*]] = mulf %[[la]], %[[lb]] : vector<16xf32> -// CHECK-VEC2: vector.scatter %{{.*}}[%[[li]]], %[[mask]], %[[m]] : memref<1024xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> +// CHECK-VEC2: vector.scatter %{{.*}}[] [%[[li]]], %[[mask]], %[[m]] : memref<1024xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> // CHECK-VEC2: } // CHECK-VEC2: return // @@ -159,9 +159,9 @@ // CHECK-VEC2: %[[mask:.*]] = vector.create_mask %[[sub]] : vector<16xi1> // CHECK-VEC2: %[[li:.*]] = vector.maskedload %{{.*}}[%[[i]]], %[[mask]], %{{.*}} : memref, vector<16xi1>, vector<16xi32> into vector<16xi32> // CHECK-VEC2: %[[la:.*]] = vector.maskedload %{{.*}}[%[[i]]], %[[mask]], %{{.*}} : memref, vector<16xi1>, vector<16xf32> into vector<16xf32> -// CHECK-VEC2: %[[lb:.*]] = vector.gather %{{.*}}[%[[li]]], %[[mask]], %{{.*}} : memref, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32> +// CHECK-VEC2: %[[lb:.*]] = vector.gather %{{.*}}[] [%[[li]]], %[[mask]], %{{.*}} : memref, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32> // CHECK-VEC2: %[[m:.*]] = mulf %[[la]], %[[lb]] : vector<16xf32> -// CHECK-VEC2: vector.scatter %{{.*}}[%[[li]]], %[[mask]], %[[m]] : memref<1024xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> +// CHECK-VEC2: vector.scatter %{{.*}}[] [%[[li]]], 
%[[mask]], %[[m]] : memref<1024xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> // CHECK-VEC2: } // CHECK-VEC2: return // @@ -324,9 +324,9 @@ // CHECK-VEC2: %[[mask:.*]] = vector.create_mask %[[sub]] : vector<16xi1> // CHECK-VEC2: %[[lj:.*]] = vector.maskedload %{{.*}}[%[[j]]], %[[mask]], %{{.*}} : memref, vector<16xi1>, vector<16xi32> into vector<16xi32> // CHECK-VEC2: %[[la:.*]] = vector.maskedload %{{.*}}[%[[j]]], %[[mask]], %{{.*}} : memref, vector<16xi1>, vector<16xf32> into vector<16xf32> -// CHECK-VEC2: %[[lb:.*]] = vector.gather %{{.*}}[%[[lj]]], %[[mask]], %{{.*}} : memref<512x1024xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32> +// CHECK-VEC2: %[[lb:.*]] = vector.gather %{{.*}}[%[[i]]] [%[[lj]]], %[[mask]], %{{.*}} : memref<512x1024xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32> // CHECK-VEC2: %[[m:.*]] = mulf %[[la]], %[[lb]] : vector<16xf32> -// CHECK-VEC2: vector.scatter %{{.*}}[%[[lj]]], %[[mask]], %[[m]] : memref<512x1024xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> +// CHECK-VEC2: vector.scatter %{{.*}}[%[[i]]] [%[[lj]]], %[[mask]], %[[m]] : memref<512x1024xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> // CHECK-VEC2: } // CHECK-VEC2: } // CHECK-VEC2: return diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir --- a/mlir/test/Dialect/Vector/canonicalize.mlir +++ b/mlir/test/Dialect/Vector/canonicalize.mlir @@ -718,7 +718,7 @@ %c0 = constant 0 : index %0 = vector.maskedload %base[%c0], %mask, %passthru : memref, vector<16xi1>, vector<16xf32> into vector<16xf32> - %1 = vector.gather %base[%indices], %mask, %passthru : + %1 = vector.gather %base[][%indices], %mask, %passthru : memref, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32> %2 = vector.expandload %base[%c0], %mask, %passthru : memref, vector<16xi1>, vector<16xf32> into vector<16xf32> diff --git a/mlir/test/Dialect/Vector/invalid.mlir 
b/mlir/test/Dialect/Vector/invalid.mlir --- a/mlir/test/Dialect/Vector/invalid.mlir +++ b/mlir/test/Dialect/Vector/invalid.mlir @@ -1288,16 +1288,25 @@ func @gather_base_type_mismatch(%base: memref, %indices: vector<16xi32>, %mask: vector<16xi1>, %pass_thru: vector<16xf32>) { // expected-error@+1 {{'vector.gather' op base and result element type should match}} - %0 = vector.gather %base[%indices], %mask, %pass_thru + %0 = vector.gather %base[][%indices], %mask, %pass_thru : memref, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32> } // ----- +func @gather_memref_mismatch(%base: memref, %indices: vector<16xi32>, + %mask: vector<16xi1>, %pass_thru: vector<16xf64>) { + // expected-error@+1 {{'vector.gather' op requires 2 indices}} + %0 = vector.gather %base[][%indices], %mask, %pass_thru + : memref, vector<16xi32>, vector<16xi1>, vector<16xf64> into vector<16xf64> +} + +// ----- + func @gather_rank_mismatch(%base: memref, %indices: vector<16xi32>, %mask: vector<16xi1>, %pass_thru: vector<16xf32>) { // expected-error@+1 {{'vector.gather' op result #0 must be of ranks 1, but got 'vector<2x16xf32>'}} - %0 = vector.gather %base[%indices], %mask, %pass_thru + %0 = vector.gather %base[][%indices], %mask, %pass_thru : memref, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<2x16xf32> } @@ -1306,7 +1315,7 @@ func @gather_dim_indices_mismatch(%base: memref, %indices: vector<17xi32>, %mask: vector<16xi1>, %pass_thru: vector<16xf32>) { // expected-error@+1 {{'vector.gather' op expected result dim to match indices dim}} - %0 = vector.gather %base[%indices], %mask, %pass_thru + %0 = vector.gather %base[][%indices], %mask, %pass_thru : memref, vector<17xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32> } @@ -1315,7 +1324,7 @@ func @gather_dim_mask_mismatch(%base: memref, %indices: vector<16xi32>, %mask: vector<17xi1>, %pass_thru: vector<16xf32>) { // expected-error@+1 {{'vector.gather' op expected result dim to match mask dim}} - %0 = vector.gather 
%base[%indices], %mask, %pass_thru + %0 = vector.gather %base[][%indices], %mask, %pass_thru : memref, vector<16xi32>, vector<17xi1>, vector<16xf32> into vector<16xf32> } @@ -1324,7 +1333,7 @@ func @gather_pass_thru_type_mismatch(%base: memref, %indices: vector<16xi32>, %mask: vector<16xi1>, %pass_thru: vector<16xf64>) { // expected-error@+1 {{'vector.gather' op expected pass_thru of same type as result type}} - %0 = vector.gather %base[%indices], %mask, %pass_thru + %0 = vector.gather %base[][%indices], %mask, %pass_thru : memref, vector<16xi32>, vector<16xi1>, vector<16xf64> into vector<16xf32> } @@ -1333,16 +1342,25 @@ func @scatter_base_type_mismatch(%base: memref, %indices: vector<16xi32>, %mask: vector<16xi1>, %value: vector<16xf32>) { // expected-error@+1 {{'vector.scatter' op base and valueToStore element type should match}} - vector.scatter %base[%indices], %mask, %value + vector.scatter %base[][%indices], %mask, %value : memref, vector<16xi32>, vector<16xi1>, vector<16xf32> } // ----- +func @scatter_memref_mismatch(%base: memref, %indices: vector<16xi32>, + %mask: vector<16xi1>, %value: vector<16xf64>) { + // expected-error@+1 {{'vector.scatter' op requires 2 indices}} + vector.scatter %base[][%indices], %mask, %value + : memref, vector<16xi32>, vector<16xi1>, vector<16xf64> +} + +// ----- + func @scatter_rank_mismatch(%base: memref, %indices: vector<16xi32>, %mask: vector<16xi1>, %value: vector<2x16xf32>) { // expected-error@+1 {{'vector.scatter' op operand #3 must be of ranks 1, but got 'vector<2x16xf32>'}} - vector.scatter %base[%indices], %mask, %value + vector.scatter %base[][%indices], %mask, %value : memref, vector<16xi32>, vector<16xi1>, vector<2x16xf32> } @@ -1351,7 +1369,7 @@ func @scatter_dim_indices_mismatch(%base: memref, %indices: vector<17xi32>, %mask: vector<16xi1>, %value: vector<16xf32>) { // expected-error@+1 {{'vector.scatter' op expected valueToStore dim to match indices dim}} - vector.scatter %base[%indices], %mask, %value + 
vector.scatter %base[][%indices], %mask, %value : memref, vector<17xi32>, vector<16xi1>, vector<16xf32> } @@ -1360,7 +1378,7 @@ func @scatter_dim_mask_mismatch(%base: memref, %indices: vector<16xi32>, %mask: vector<17xi1>, %value: vector<16xf32>) { // expected-error@+1 {{'vector.scatter' op expected valueToStore dim to match mask dim}} - vector.scatter %base[%indices], %mask, %value + vector.scatter %base[][%indices], %mask, %value : memref, vector<16xi32>, vector<17xi1>, vector<16xf32> } diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir --- a/mlir/test/Dialect/Vector/ops.mlir +++ b/mlir/test/Dialect/Vector/ops.mlir @@ -541,25 +541,55 @@ return } +// CHECK-LABEL: @masked_load_and_store2d +func @masked_load_and_store2d(%base: memref, %mask: vector<16xi1>, %passthru: vector<16xf32>) { + %c0 = constant 0 : index + // CHECK: %[[X:.*]] = vector.maskedload %{{.*}}[%{{.*}}, %{{.*}}], %{{.*}}, %{{.*}} : memref, vector<16xi1>, vector<16xf32> into vector<16xf32> + %0 = vector.maskedload %base[%c0, %c0], %mask, %passthru : memref, vector<16xi1>, vector<16xf32> into vector<16xf32> + // CHECK: vector.maskedstore %{{.*}}[%{{.*}}, %{{.*}}], %{{.*}}, %[[X]] : memref, vector<16xi1>, vector<16xf32> + vector.maskedstore %base[%c0, %c0], %mask, %0 : memref, vector<16xi1>, vector<16xf32> + return +} + // CHECK-LABEL: @gather_and_scatter -func @gather_and_scatter(%base: memref, %indices: vector<16xi32>, %mask: vector<16xi1>, %pass_thru: vector<16xf32>) { - // CHECK: %[[X:.*]] = vector.gather %{{.*}}[%{{.*}}], %{{.*}}, %{{.*}} : memref, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32> - %0 = vector.gather %base[%indices], %mask, %pass_thru : memref, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32> - // CHECK: vector.scatter %{{.*}}[%{{.*}}], %{{.*}}, %[[X]] : memref, vector<16xi32>, vector<16xi1>, vector<16xf32> - vector.scatter %base[%indices], %mask, %0 : memref, vector<16xi32>, vector<16xi1>, vector<16xf32> +func 
@gather_and_scatter(%base: memref, %v: vector<16xi32>, %mask: vector<16xi1>, %pass_thru: vector<16xf32>) { + // CHECK: %[[X:.*]] = vector.gather %{{.*}}[] [%{{.*}}], %{{.*}}, %{{.*}} : memref, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32> + %0 = vector.gather %base[][%v], %mask, %pass_thru : memref, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32> + // CHECK: vector.scatter %{{.*}}[] [%{{.*}}], %{{.*}}, %[[X]] : memref, vector<16xi32>, vector<16xi1>, vector<16xf32> + vector.scatter %base[][%v], %mask, %0 : memref, vector<16xi32>, vector<16xi1>, vector<16xf32> + return +} + +// CHECK-LABEL: @gather_and_scatter2d +func @gather_and_scatter2d(%base: memref, %v: vector<16xi32>, %mask: vector<16xi1>, %pass_thru: vector<16xf32>) { + %c0 = constant 0 : index + // CHECK: %[[X:.*]] = vector.gather %{{.*}}[%{{.*}}], %{{.*}}, %{{.*}} : memref, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32> + %0 = vector.gather %base[%c0][%v], %mask, %pass_thru : memref, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32> + // CHECK: vector.scatter %{{.*}}[%{{.*}}], %{{.*}}, %[[X]] : memref, vector<16xi32>, vector<16xi1>, vector<16xf32> + vector.scatter %base[%c0][%v], %mask, %0 : memref, vector<16xi32>, vector<16xi1>, vector<16xf32> return } // CHECK-LABEL: @expand_and_compress func @expand_and_compress(%base: memref, %mask: vector<16xi1>, %pass_thru: vector<16xf32>) { %c0 = constant 0 : index - // CHECK: %[[X:.*]] = vector.expandload %{{.*}}[{{.*}}], %{{.*}}, %{{.*}} : memref, vector<16xi1>, vector<16xf32> into vector<16xf32> + // CHECK: %[[X:.*]] = vector.expandload %{{.*}}[%{{.*}}], %{{.*}}, %{{.*}} : memref, vector<16xi1>, vector<16xf32> into vector<16xf32> %0 = vector.expandload %base[%c0], %mask, %pass_thru : memref, vector<16xi1>, vector<16xf32> into vector<16xf32> - // CHECK: vector.compressstore %{{.*}}[{{.*}}], %{{.*}}, %[[X]] : memref, vector<16xi1>, vector<16xf32> + // CHECK: vector.compressstore 
%{{.*}}[%{{.*}}], %{{.*}}, %[[X]] : memref, vector<16xi1>, vector<16xf32> vector.compressstore %base[%c0], %mask, %0 : memref, vector<16xi1>, vector<16xf32> return } +// CHECK-LABEL: @expand_and_compress2d +func @expand_and_compress2d(%base: memref, %mask: vector<16xi1>, %pass_thru: vector<16xf32>) { + %c0 = constant 0 : index + // CHECK: %[[X:.*]] = vector.expandload %{{.*}}[%{{.*}}, %{{.*}}], %{{.*}}, %{{.*}} : memref, vector<16xi1>, vector<16xf32> into vector<16xf32> + %0 = vector.expandload %base[%c0, %c0], %mask, %pass_thru : memref, vector<16xi1>, vector<16xf32> into vector<16xf32> + // CHECK: vector.compressstore %{{.*}}[%{{.*}}, %{{.*}}], %{{.*}}, %[[X]] : memref, vector<16xi1>, vector<16xf32> + vector.compressstore %base[%c0, %c0], %mask, %0 : memref, vector<16xi1>, vector<16xf32> + return +} + // CHECK-LABEL: @extract_insert_map func @extract_insert_map(%v: vector<32xf32>, %v2: vector<16x32xf32>, %id0 : index, %id1 : index) -> (vector<32xf32>, vector<16x32xf32>) { diff --git a/mlir/test/Dialect/Vector/vector-mem-transforms.mlir b/mlir/test/Dialect/Vector/vector-mem-transforms.mlir --- a/mlir/test/Dialect/Vector/vector-mem-transforms.mlir +++ b/mlir/test/Dialect/Vector/vector-mem-transforms.mlir @@ -86,11 +86,11 @@ // CHECK-SAME: %[[A1:.*]]: vector<16xi32>, // CHECK-SAME: %[[A2:.*]]: vector<16xf32>) -> vector<16xf32> { // CHECK-NEXT: %[[M:.*]] = vector.constant_mask [16] : vector<16xi1> -// CHECK-NEXT: %[[G:.*]] = vector.gather %[[A0]][%[[A1]]], %[[M]], %[[A2]] : memref<16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32> +// CHECK-NEXT: %[[G:.*]] = vector.gather %[[A0]][] [%[[A1]]], %[[M]], %[[A2]] : memref<16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32> // CHECK-NEXT: return %[[G]] : vector<16xf32> func @gather1(%base: memref<16xf32>, %indices: vector<16xi32>, %pass_thru: vector<16xf32>) -> vector<16xf32> { %mask = vector.constant_mask [16] : vector<16xi1> - %ld = vector.gather %base[%indices], %mask, 
%pass_thru + %ld = vector.gather %base[][%indices], %mask, %pass_thru : memref<16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32> return %ld : vector<16xf32> } @@ -102,7 +102,7 @@ // CHECK-NEXT: return %[[A2]] : vector<16xf32> func @gather2(%base: memref<16xf32>, %indices: vector<16xi32>, %pass_thru: vector<16xf32>) -> vector<16xf32> { %mask = vector.constant_mask [0] : vector<16xi1> - %ld = vector.gather %base[%indices], %mask, %pass_thru + %ld = vector.gather %base[][%indices], %mask, %pass_thru : memref<16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32> return %ld : vector<16xf32> } @@ -112,11 +112,11 @@ // CHECK-SAME: %[[A1:.*]]: vector<16xi32>, // CHECK-SAME: %[[A2:.*]]: vector<16xf32>) { // CHECK-NEXT: %[[M:.*]] = vector.constant_mask [16] : vector<16xi1> -// CHECK-NEXT: vector.scatter %[[A0]][%[[A1]]], %[[M]], %[[A2]] : memref<16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> +// CHECK-NEXT: vector.scatter %[[A0]][] [%[[A1]]], %[[M]], %[[A2]] : memref<16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> // CHECK-NEXT: return func @scatter1(%base: memref<16xf32>, %indices: vector<16xi32>, %value: vector<16xf32>) { %mask = vector.constant_mask [16] : vector<16xi1> - vector.scatter %base[%indices], %mask, %value + vector.scatter %base[][%indices], %mask, %value : memref<16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> return } @@ -129,7 +129,7 @@ func @scatter2(%base: memref<16xf32>, %indices: vector<16xi32>, %value: vector<16xf32>) { %0 = vector.type_cast %base : memref<16xf32> to memref> %mask = vector.constant_mask [0] : vector<16xi1> - vector.scatter %base[%indices], %mask, %value + vector.scatter %base[][%indices], %mask, %value : memref<16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> return } diff --git a/mlir/test/Integration/Dialect/Vector/CPU/test-gather.mlir b/mlir/test/Integration/Dialect/Vector/CPU/test-gather.mlir --- a/mlir/test/Integration/Dialect/Vector/CPU/test-gather.mlir +++ 
b/mlir/test/Integration/Dialect/Vector/CPU/test-gather.mlir @@ -5,7 +5,7 @@ func @gather8(%base: memref, %indices: vector<8xi32>, %mask: vector<8xi1>, %pass_thru: vector<8xf32>) -> vector<8xf32> { - %g = vector.gather %base[%indices], %mask, %pass_thru + %g = vector.gather %base[][%indices], %mask, %pass_thru : memref, vector<8xi32>, vector<8xi1>, vector<8xf32> into vector<8xf32> return %g : vector<8xf32> } diff --git a/mlir/test/Integration/Dialect/Vector/CPU/test-scatter.mlir b/mlir/test/Integration/Dialect/Vector/CPU/test-scatter.mlir --- a/mlir/test/Integration/Dialect/Vector/CPU/test-scatter.mlir +++ b/mlir/test/Integration/Dialect/Vector/CPU/test-scatter.mlir @@ -6,7 +6,7 @@ func @scatter8(%base: memref, %indices: vector<8xi32>, %mask: vector<8xi1>, %value: vector<8xf32>) { - vector.scatter %base[%indices], %mask, %value + vector.scatter %base[][%indices], %mask, %value : memref, vector<8xi32>, vector<8xi1>, vector<8xf32> return } diff --git a/mlir/test/Integration/Dialect/Vector/CPU/test-sparse-dot-matvec.mlir b/mlir/test/Integration/Dialect/Vector/CPU/test-sparse-dot-matvec.mlir --- a/mlir/test/Integration/Dialect/Vector/CPU/test-sparse-dot-matvec.mlir +++ b/mlir/test/Integration/Dialect/Vector/CPU/test-sparse-dot-matvec.mlir @@ -64,7 +64,7 @@ scf.for %i = %c0 to %cn step %c1 { %aval = load %AVAL[%i] : memref<8xvector<4xf32>> %aidx = load %AIDX[%i] : memref<8xvector<4xi32>> - %0 = vector.gather %X[%aidx], %mask, %pass + %0 = vector.gather %X[][%aidx], %mask, %pass : memref, vector<4xi32>, vector<4xi1>, vector<4xf32> into vector<4xf32> %1 = vector.contract #dot_trait %aval, %0, %f0 : vector<4xf32>, vector<4xf32> into f32 store %1, %B[%i] : memref diff --git a/mlir/test/Integration/Dialect/Vector/CPU/test-sparse-saxpy-jagged-matvec.mlir b/mlir/test/Integration/Dialect/Vector/CPU/test-sparse-saxpy-jagged-matvec.mlir --- a/mlir/test/Integration/Dialect/Vector/CPU/test-sparse-saxpy-jagged-matvec.mlir +++ 
b/mlir/test/Integration/Dialect/Vector/CPU/test-sparse-saxpy-jagged-matvec.mlir @@ -57,7 +57,7 @@ %b_out = scf.for %k = %c0 to %cn step %c1 iter_args(%b_iter = %b) -> (vector<8xf32>) { %aval = load %AVAL[%k] : memref<4xvector<8xf32>> %aidx = load %AIDX[%k] : memref<4xvector<8xi32>> - %0 = vector.gather %X[%aidx], %mask, %pass + %0 = vector.gather %X[][%aidx], %mask, %pass : memref, vector<8xi32>, vector<8xi1>, vector<8xf32> into vector<8xf32> %b_new = vector.fma %aval, %0, %b_iter : vector<8xf32> scf.yield %b_new : vector<8xf32>