diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -765,6 +765,133 @@
   let hasVerifier = 1;
 }
 
+def Vector_Insert1DOp :
+  Vector_Op<"insert_1d", [NoSideEffect,
+    AllElementTypesMatch<["source", "dest"]>,
+    AllTypesMatch<["dest", "res"]>,
+    // Negative positions can satisfy the alignment check and, for scalable
+    // destinations, bypass the bounds check, so reject them explicitly.
+    PredOpTrait<"position is non-negative.", CPred<"(int64_t)getPos() >= 0">>,
+    PredOpTrait<"position is a multiple of the source length.",
+      CPred<
+        "(getPos() % getSourceVectorType().getNumElements()) == 0"
+      >>,
+    PredOpTrait<"source vector fits in destination position.", Or<[
+      CPred<"getDestVectorType().isScalable()">,
+      CPred<"(getSourceVectorType().getNumElements() + (int64_t)getPos() - 1)"
+            "< getDestVectorType().getNumElements()">]>>]>,
+    Arguments<(ins FixedVectorOfRank<[1]>:$source,
+                   VectorOfRank<[1]>:$dest,
+                   I64Attr:$pos)>,
+    Results<(outs VectorOfRank<[1]>:$res)> {
+  let summary = "insert subvector into vector operation";
+  let description = [{
+    Takes a rank-1 source subvector with `n` elements, a rank-1 destination
+    vector with `m` elements, and a position `pos` within the destination
+    vector, and inserts the elements `(0, ..., n - 1)` of the subvector into
+    the destination vector at positions `(pos, ..., pos + n - 1)`. Source and
+    destination must have the same element type.
+
+    The source vector must be a fixed-length vector, but the destination can
+    be scalable, enabling the construction of scalable vectors from
+    fixed-length vectors.
+
+    The insertion position must be a non-negative multiple of the size of the
+    source vector. If the destination vector is fixed-length, `(pos + n - 1)`
+    must be smaller than `m`. If the destination vector is scalable,
+    `(pos + n - 1)` must be smaller than the runtime length of the destination
+    vector `(vscale x m)`, which can't be verified at compile time.
+
+    Example:
+
+    ```mlir
+    %2 = vector.insert_1d %0, %1[8] : vector<4xf32> into vector<16xf32>
+    %5 = vector.insert_1d %3, %4[0] : vector<8xf32> into vector<[4]xf32>
+    ```
+
+    Invalid example:
+    ```mlir
+    %2 = vector.insert_1d %0, %1[5] : vector<4xf32> into vector<16xf32>
+    %5 = vector.insert_1d %3, %4[16] : vector<4xf32> into vector<16xf32>
+    %8 = vector.insert_1d %6, %7[-4] : vector<4xf32> into vector<16xf32>
+    ```
+  }];
+
+  let assemblyFormat = [{
+    $source `,` $dest `[` $pos `]` attr-dict `:` type($source) `into` type($dest)
+  }];
+
+  let extraClassDeclaration = [{
+    VectorType getSourceVectorType() {
+      return getSource().getType().cast<VectorType>();
+    }
+    VectorType getDestVectorType() {
+      return getDest().getType().cast<VectorType>();
+    }
+  }];
+}
+
+def Vector_Extract1DOp :
+  Vector_Op<"extract_1d", [NoSideEffect,
+    AllElementTypesMatch<["source", "res"]>,
+    // Negative positions can satisfy the alignment check and, for scalable
+    // sources, bypass the bounds check, so reject them explicitly.
+    PredOpTrait<"position is non-negative.", CPred<"(int64_t)getPos() >= 0">>,
+    PredOpTrait<"position is a multiple of the result length.",
+      CPred<
+        "(getPos() % getResultVectorType().getNumElements()) == 0"
+      >>,
+    PredOpTrait<"is not accessing elements beyond the source length.", Or<[
+      CPred<"getSourceVectorType().isScalable()">,
+      CPred<"(getResultVectorType().getNumElements() + (int64_t)getPos() - 1)"
+            "< getSourceVectorType().getNumElements()">]>>]>,
+    Arguments<(ins VectorOfRank<[1]>:$source,
+                   I64Attr:$pos)>,
+    Results<(outs FixedVectorOfRank<[1]>:$res)> {
+  let summary = "extract subvector from vector operation";
+  let description = [{
+    Takes a rank-1 source vector with `n` elements and a position `pos` within
+    the source vector, and extracts the elements `(pos, ..., pos + m - 1)` of
+    the source vector into a new rank-1 fixed-length vector with `m` elements
+    and the same element type. `pos` must be non-negative and aligned to the
+    length of the resulting vector, that is, it must be a multiple of `m`.
+
+    This operation can only extract fixed-length vectors, but the source vector
+    can be either fixed-length or scalable. If the source vector is
+    fixed-length, `(pos + m - 1)` must be smaller than `n`. If the source vector
+    is scalable, `(pos + m - 1)` must be smaller than the runtime length of the
+    source vector `(vscale x n)`, which can't be verified at compile time.
+
+    Example:
+
+    ```mlir
+    %1 = vector.extract_1d %0[8] : vector<4xf32> from vector<16xf32>
+    %3 = vector.extract_1d %2[0] : vector<8xf32> from vector<[4]xf32>
+    ```
+
+    Invalid example:
+    ```mlir
+    %1 = vector.extract_1d %0[5] : vector<4xf32> from vector<16xf32>
+    %3 = vector.extract_1d %2[16] : vector<4xf32> from vector<16xf32>
+    %5 = vector.extract_1d %4[-4] : vector<4xf32> from vector<16xf32>
+    %7 = vector.extract_1d %6[0] : vector<4xi32> from vector<16xf32>
+    ```
+  }];
+
+  let assemblyFormat = [{
+    $source `[` $pos `]` attr-dict `:` type($res) `from` type($source)
+  }];
+
+  let extraClassDeclaration = [{
+    VectorType getSourceVectorType() {
+      return getSource().getType().cast<VectorType>();
+    }
+    VectorType getResultVectorType() {
+      return getRes().getType().cast<VectorType>();
+    }
+  }];
+}
+
 def Vector_InsertMapOp :
   Vector_Op<"insert_map", [NoSideEffect, AllTypesMatch<["dest", "result"]>]>,
     Arguments<(ins AnyVector:$vector, AnyVector:$dest, Variadic<Index>:$ids)>,
diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td
--- a/mlir/include/mlir/IR/OpBase.td
+++ b/mlir/include/mlir/IR/OpBase.td
@@ -564,11 +564,28 @@
                            == }]
                          # allowedlength>)>]>;
 
+// Whether the rank of a fixed-length (non-scalable) vector is from the given
+// `allowedRanks` list
+class IsFixedVectorOfRankPred<list<int> allowedRanks> :
+  And<[IsFixedVectorTypePred,
+       Or<!foreach(allowedRank, allowedRanks,
+                   CPred<[{$_self.cast<::mlir::VectorType>().getRank()
+                           == }]
+                         # allowedRank>)>]>;
+
 // Any vector where the rank is from the given `allowedRanks` list
 class VectorOfRank<list<int> allowedRanks> : Type<
   IsVectorOfRankPred<allowedRanks>,
   " of ranks " # !interleave(allowedRanks, "/"), "::mlir::VectorType">;
 
+// Any fixed-length vector where the rank is from the given `allowedRanks` list.
+// The summary names the vector kind so that verifier errors on operands using
+// this constraint directly say what is expected.
+class FixedVectorOfRank<list<int> allowedRanks> : Type<
+  IsFixedVectorOfRankPred<allowedRanks>,
+  "fixed-length vector of ranks " # !interleave(allowedRanks, "/"),
+  "::mlir::VectorType">;
+
 // Any vector where the rank is from the given `allowedRanks` list and the type
 // is from the given `allowedTypes` list
 class VectorOfRankAndType<list<int> allowedRanks,
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -737,6 +737,37 @@
   }
 };
 
+/// Lower vector.insert_1d to llvm.intr.experimental.vector.insert.
+struct VectorInsert1DOpLowering
+    : public ConvertOpToLLVMPattern<vector::Insert1DOp> {
+  using ConvertOpToLLVMPattern<vector::Insert1DOp>::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(vector::Insert1DOp ins1DOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    rewriter.replaceOpWithNewOp<LLVM::vector_insert>(
+        ins1DOp, adaptor.getSource(), adaptor.getDest(), adaptor.getPos());
+    return success();
+  }
+};
+
+/// Lower vector.extract_1d to llvm.intr.experimental.vector.extract.
+struct VectorExtract1DOpLowering
+    : public ConvertOpToLLVMPattern<vector::Extract1DOp> {
+  using ConvertOpToLLVMPattern<vector::Extract1DOp>::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(vector::Extract1DOp ext1DOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    // The result length is not implied by the source and position, so the
+    // converted result type has to be passed explicitly.
+    rewriter.replaceOpWithNewOp<LLVM::vector_extract>(
+        ext1DOp, typeConverter->convertType(ext1DOp.getResultVectorType()),
+        adaptor.getSource(), adaptor.getPos());
+    return success();
+  }
+};
+
 /// Rank reducing rewrite for n-D FMA into (n-1)-D FMA where n > 1.
 ///
 /// Example:
@@ -1214,7 +1245,8 @@
                                 vector::MaskedStoreOpAdaptor>,
       VectorGatherOpConversion, VectorScatterOpConversion,
       VectorExpandLoadOpConversion, VectorCompressStoreOpConversion,
-      VectorSplatOpLowering, VectorSplatNdOpLowering>(converter);
+      VectorSplatOpLowering, VectorSplatNdOpLowering,
+      VectorInsert1DOpLowering, VectorExtract1DOpLowering>(converter);
   // Transfer ops with rank > 1 are handled by VectorToSCF.
   populateVectorTransferLoweringPatterns(patterns, /*maxTransferRank=*/1);
 }
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -1814,3 +1814,27 @@
 // CHECK-NEXT: %[[SPLAT:[0-9]+]] = llvm.shufflevector %[[V]], %[[UNDEF]] [0 : i32, 0 : i32, 0 : i32, 0 : i32]
 // CHECK-NEXT: %[[SCALE:[0-9]+]] = arith.mulf %[[A]], %[[SPLAT]] : vector<4xf32>
 // CHECK-NEXT: return %[[SCALE]] : vector<4xf32>
+
+// -----
+
+// CHECK-LABEL: @vector_insert_1d
+// CHECK-SAME: %[[SUB:.*]]: vector<4xf32>, %[[FL:.*]]: vector<8xf32>, %[[SV:.*]]: vector<[4]xf32>
+func.func @vector_insert_1d(%sub: vector<4xf32>, %dfl: vector<8xf32>, %dsv: vector<[4]xf32>) -> vector<[4]xf32> {
+  // CHECK-NEXT: %[[BFL:.*]] = llvm.intr.experimental.vector.insert %[[SUB]], %[[FL]][4] : vector<4xf32> into vector<8xf32>
+  %0 = vector.insert_1d %sub, %dfl[4] : vector<4xf32> into vector<8xf32>
+  // CHECK-NEXT: llvm.intr.experimental.vector.insert %[[BFL]], %[[SV]][0] : vector<8xf32> into vector<[4]xf32>
+  %1 = vector.insert_1d %0, %dsv[0] : vector<8xf32> into vector<[4]xf32>
+  return %1 : vector<[4]xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @vector_extract_1d
+// CHECK-SAME: %[[IN:.*]]: vector<[4]xf32>
+func.func @vector_extract_1d(%in: vector<[4]xf32>) -> vector<4xf32> {
+  // CHECK-NEXT: %[[BFL:.*]] = llvm.intr.experimental.vector.extract %[[IN]][0] : vector<8xf32> from vector<[4]xf32>
+  %0 = vector.extract_1d %in[0] : vector<8xf32> from vector<[4]xf32>
+  // CHECK-NEXT: llvm.intr.experimental.vector.extract %[[BFL]][4] : vector<4xf32> from vector<8xf32>
+  %1 = vector.extract_1d %0[4] : vector<4xf32> from vector<8xf32>
+  return %1 : vector<4xf32>
+}
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -1612,3 +1612,56 @@
   }
   return
 }
+
+// -----
+
+func.func @vector_insert_1d_bad_position(%subv: vector<4xi32>, %vec: vector<16xi32>) {
+  // expected-error@+1 {{op failed to verify that position is a multiple of the source length.}}
+  %0 = vector.insert_1d %subv, %vec[2] : vector<4xi32> into vector<16xi32>
+}
+
+// -----
+
+func.func @vector_insert_1d_bad_fit(%subv: vector<4xi32>, %vec: vector<16xi32>) {
+  // expected-error@+1 {{op failed to verify that source vector fits in destination position.}}
+  %0 = vector.insert_1d %subv, %vec[16] : vector<4xi32> into vector<16xi32>
+}
+
+// -----
+
+func.func @vector_insert_1d_negative_position(%subv: vector<4xi32>, %vec: vector<[4]xi32>) {
+  // expected-error@+1 {{op failed to verify that position is non-negative.}}
+  %0 = vector.insert_1d %subv, %vec[-4] : vector<4xi32> into vector<[4]xi32>
+}
+
+// -----
+
+func.func @vector_extract_1d_oob(%v: vector<24xi8>) -> vector<16xi8> {
+  // expected-error@+1 {{op failed to verify that is not accessing elements beyond the source length.}}
+  %0 = vector.extract_1d %v[16] : vector<16xi8> from vector<24xi8>
+  return %0 : vector<16xi8>
+}
+
+// -----
+
+func.func @vector_extract_1d_unaligned(%v: vector<16xf32>) -> vector<4xf32> {
+  // expected-error@+1 {{op failed to verify that position is a multiple of the result length.}}
+  %0 = vector.extract_1d %v[5] : vector<4xf32> from vector<16xf32>
+  return %0 : vector<4xf32>
+}
+
+// -----
+
+func.func @vector_extract_1d_negative_position(%v: vector<[4]xi8>) -> vector<4xi8> {
+  // expected-error@+1 {{op failed to verify that position is non-negative.}}
+  %0 = vector.extract_1d %v[-4] : vector<4xi8> from vector<[4]xi8>
+  return %0 : vector<4xi8>
+}
+
+// -----
+
+func.func @vector_extract_1d_element_type_mismatch(%v: vector<16xf32>) -> vector<4xi32> {
+  // expected-error@+1 {{have same element type}}
+  %0 = vector.extract_1d %v[0] : vector<4xi32> from vector<16xf32>
+  return %0 : vector<4xi32>
+}
diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir
--- a/mlir/test/Dialect/Vector/ops.mlir
+++ b/mlir/test/Dialect/Vector/ops.mlir
@@ -771,4 +771,32 @@
   return %2 : vector<4xi32>
 }
 
-
+// CHECK-LABEL: func @vector_insert_1d(
+// CHECK-SAME: %[[SUBS:.*]]: vector<4xi32>, %[[SUBL:.*]]: vector<8xi32>,
+// CHECK-SAME: %[[FV:.*]]: vector<16xi32>, %[[SV:.*]]: vector<[4]xi32>
+func.func @vector_insert_1d(%subs: vector<4xi32>, %subl: vector<8xi32>,
+                            %flv: vector<16xi32>, %sv: vector<[4]xi32>) {
+  // CHECK-NEXT: vector.insert_1d %[[SUBL]], %[[FV]][0] : vector<8xi32> into vector<16xi32>
+  %0 = vector.insert_1d %subl, %flv[0] : vector<8xi32> into vector<16xi32>
+  // CHECK-NEXT: vector.insert_1d %[[SUBS]], %[[FV]][12] : vector<4xi32> into vector<16xi32>
+  %1 = vector.insert_1d %subs, %flv[12] : vector<4xi32> into vector<16xi32>
+  // CHECK-NEXT: vector.insert_1d %[[SUBL]], %[[SV]][0] : vector<8xi32> into vector<[4]xi32>
+  %2 = vector.insert_1d %subl, %sv[0] : vector<8xi32> into vector<[4]xi32>
+  // CHECK-NEXT: vector.insert_1d %[[SUBS]], %[[SV]][0] : vector<4xi32> into vector<[4]xi32>
+  %3 = vector.insert_1d %subs, %sv[0] : vector<4xi32> into vector<[4]xi32>
+  return
+}
+
+// CHECK-LABEL: func @vector_extract_1d(
+// CHECK-SAME: %[[FV:.*]]: vector<16xi32>, %[[SV:.*]]: vector<[4]xi32>
+func.func @vector_extract_1d(%flv: vector<16xi32>, %sv: vector<[4]xi32>) {
+  // CHECK-NEXT: vector.extract_1d %[[FV]][0] : vector<8xi32> from vector<16xi32>
+  %0 = vector.extract_1d %flv[0] : vector<8xi32> from vector<16xi32>
+  // CHECK-NEXT: vector.extract_1d %[[FV]][12] : vector<4xi32> from vector<16xi32>
+  %1 = vector.extract_1d %flv[12] : vector<4xi32> from vector<16xi32>
+  // CHECK-NEXT: vector.extract_1d %[[SV]][0] : vector<8xi32> from vector<[4]xi32>
+  %2 = vector.extract_1d %sv[0] : vector<8xi32> from vector<[4]xi32>
+  // CHECK-NEXT: vector.extract_1d %[[SV]][4] : vector<4xi32> from vector<[4]xi32>
+  %3 = vector.extract_1d %sv[4] : vector<4xi32> from vector<[4]xi32>
+  return
+}