diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -727,6 +727,114 @@
   let hasVerifier = 1;
 }
 
+def Vector_ScalableInsertOp :
+    Vector_Op<"scalable.insert", [Pure,
+                                  AllElementTypesMatch<["source", "dest"]>,
+                                  AllTypesMatch<["dest", "res"]>,
+                                  PredOpTrait<"position is a multiple of the source length.",
+                                    CPred<
+                                      "(getPos() % getSourceVectorType().getNumElements()) == 0"
+                                    >>]>,
+    Arguments<(ins VectorOfRank<[1]>:$source,
+                   ScalableVectorOfRank<[1]>:$dest,
+                   I64Attr:$pos)>,
+    Results<(outs ScalableVectorOfRank<[1]>:$res)> {
+  let summary = "insert subvector into scalable vector operation";
+  // NOTE: This operation is designed to map to `llvm.vector.insert`, and its
+  // documentation should be kept aligned with LLVM IR:
+  // https://llvm.org/docs/LangRef.html#llvm-vector-insert-intrinsic
+  let description = [{
+    This operation takes a rank-1 fixed-length or scalable subvector and
+    inserts it within the destination scalable vector starting from the
+    position specified by `pos`. If the source vector is scalable, the
+    insertion position will be scaled by the runtime scaling factor of the
+    source subvector.
+
+    The insertion position must be a multiple of the minimum size of the source
+    vector. For the operation to be well defined, the source vector must fit in
+    the destination vector from the specified position. Since the destination
+    vector is scalable and its runtime length is unknown, the validity of the
+    operation can't be verified nor guaranteed at compile time.
+
+    Example:
+
+    ```mlir
+    %2 = vector.scalable.insert %0, %1[8] : vector<4xf32> into vector<[16]xf32>
+    %5 = vector.scalable.insert %3, %4[0] : vector<8xf32> into vector<[4]xf32>
+    %8 = vector.scalable.insert %6, %7[0] : vector<[4]xf32> into vector<[8]xf32>
+    ```
+
+    Invalid example:
+    ```mlir
+    %2 = vector.scalable.insert %0, %1[5] : vector<4xf32> into vector<[16]xf32>
+    ```
+  }];
+
+  let assemblyFormat = [{
+    $source `,` $dest `[` $pos `]` attr-dict `:` type($source) `into` type($dest)
+  }];
+
+  let extraClassDeclaration = [{
+    VectorType getSourceVectorType() {
+      return getSource().getType().cast<VectorType>();
+    }
+    VectorType getDestVectorType() {
+      return getDest().getType().cast<VectorType>();
+    }
+  }];
+}
+
+def Vector_ScalableExtractOp :
+    Vector_Op<"scalable.extract", [Pure,
+                                   AllElementTypesMatch<["source", "res"]>,
+                                   PredOpTrait<"position is a multiple of the result length.",
+                                     CPred<
+                                       "(getPos() % getResultVectorType().getNumElements()) == 0"
+                                     >>]>,
+    Arguments<(ins ScalableVectorOfRank<[1]>:$source,
+                   I64Attr:$pos)>,
+    Results<(outs VectorOfRank<[1]>:$res)> {
+  let summary = "extract subvector from scalable vector operation";
+  // NOTE: This operation is designed to map to `llvm.vector.extract`, and its
+  // documentation should be kept aligned with LLVM IR:
+  // https://llvm.org/docs/LangRef.html#llvm-vector-extract-intrinsic
+  let description = [{
+    Takes a rank-1 source vector and a position `pos` within the source
+    vector, and extracts a subvector starting from that position.
+
+    The extraction position must be a multiple of the minimum size of the result
+    vector. For the operation to be well defined, the destination vector must
+    fit within the source vector from the specified position. Since the source
+    vector is scalable and its runtime length is unknown, the validity of the
+    operation can't be verified nor guaranteed at compile time.
+
+    Example:
+
+    ```mlir
+    %1 = vector.scalable.extract %0[8] : vector<4xf32> from vector<[8]xf32>
+    %3 = vector.scalable.extract %2[0] : vector<[4]xf32> from vector<[8]xf32>
+    ```
+
+    Invalid example:
+    ```mlir
+    %1 = vector.scalable.extract %0[5] : vector<4xf32> from vector<[16]xf32>
+    ```
+  }];
+
+  let assemblyFormat = [{
+    $source `[` $pos `]` attr-dict `:` type($res) `from` type($source)
+  }];
+
+  let extraClassDeclaration = [{
+    VectorType getSourceVectorType() {
+      return getSource().getType().cast<VectorType>();
+    }
+    VectorType getResultVectorType() {
+      return getRes().getType().cast<VectorType>();
+    }
+  }];
+}
+
 def Vector_InsertStridedSliceOp :
   Vector_Op<"insert_strided_slice", [Pure,
     PredOpTrait<"operand #0 and result have same element type",
diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td
--- a/mlir/include/mlir/IR/OpBase.td
+++ b/mlir/include/mlir/IR/OpBase.td
@@ -579,11 +579,39 @@
                            == }]
                          # allowedlength>)>]>;
 
+// Whether the number of elements of a fixed-length vector is from the given
+// `allowedRanks` list
+class IsFixedVectorOfRankPred<list<int> allowedRanks> :
+  And<[IsFixedVectorTypePred,
+       Or<!foreach(allowedlength, allowedRanks,
+                   CPred<[{$_self.cast<::mlir::VectorType>().getRank()
+                           == }]
+                         # allowedlength>)>]>;
+
+// Whether the number of elements of a scalable vector is from the given
+// `allowedRanks` list
+class IsScalableVectorOfRankPred<list<int> allowedRanks> :
+  And<[IsScalableVectorTypePred,
+       Or<!foreach(allowedlength, allowedRanks,
+                   CPred<[{$_self.cast<::mlir::VectorType>().getRank()
+                           == }]
+                         # allowedlength>)>]>;
+
 // Any vector where the rank is from the given `allowedRanks` list
 class VectorOfRank<list<int> allowedRanks> : Type<
   IsVectorOfRankPred<allowedRanks>,
   " of ranks " # !interleave(allowedRanks, "/"), "::mlir::VectorType">;
 
+// Any fixed-length vector where the rank is from the given `allowedRanks` list
+class FixedVectorOfRank<list<int> allowedRanks> : Type<
+  IsFixedVectorOfRankPred<allowedRanks>,
+  " of ranks " # !interleave(allowedRanks, "/"), "::mlir::VectorType">;
+
+// Any scalable vector where the rank is from the given `allowedRanks` list
+class ScalableVectorOfRank<list<int> allowedRanks> : Type<
+  IsScalableVectorOfRankPred<allowedRanks>,
+  " of ranks 
" # !interleave(allowedRanks, "/"), "::mlir::VectorType">; + // Any vector where the rank is from the given `allowedRanks` list and the type // is from the given `allowedTypes` list class VectorOfRankAndType allowedRanks, diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp @@ -857,6 +857,37 @@ } }; +/// Lower vector.scalable.insert ops to LLVM vector.insert +struct VectorScalableInsertOpLowering + : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern< + vector::ScalableInsertOp>::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(vector::ScalableInsertOp insOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + rewriter.replaceOpWithNewOp( + insOp, adaptor.getSource(), adaptor.getDest(), adaptor.getPos()); + return success(); + } +}; + +/// Lower vector.scalable.extract ops to LLVM vector.extract +struct VectorScalableExtractOpLowering + : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern< + vector::ScalableExtractOp>::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(vector::ScalableExtractOp extOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + rewriter.replaceOpWithNewOp( + extOp, typeConverter->convertType(extOp.getResultVectorType()), + adaptor.getSource(), adaptor.getPos()); + return success(); + } +}; + /// Rank reducing rewrite for n-D FMA into (n-1)-D FMA where n > 1. 
 ///
 /// Example:
@@ -1329,7 +1360,9 @@
                   vector::MaskedStoreOpAdaptor>,
                VectorGatherOpConversion, VectorScatterOpConversion,
                VectorExpandLoadOpConversion, VectorCompressStoreOpConversion,
-               VectorSplatOpLowering, VectorSplatNdOpLowering>(converter);
+               VectorSplatOpLowering, VectorSplatNdOpLowering,
+               VectorScalableInsertOpLowering, VectorScalableExtractOpLowering>(
+      converter);
   // Transfer ops with rank > 1 are handled by VectorToSCF.
   populateVectorTransferLoweringPatterns(patterns, /*maxTransferRank=*/1);
 }
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -2140,3 +2140,25 @@
 // CHECK-NEXT: %[[SPLAT:[0-9]+]] = llvm.shufflevector %[[V]], %[[UNDEF]] [0, 0, 0, 0]
 // CHECK-NEXT: %[[SCALE:[0-9]+]] = arith.mulf %[[A]], %[[SPLAT]] : vector<4xf32>
 // CHECK-NEXT: return %[[SCALE]] : vector<4xf32>
+
+// -----
+
+// CHECK-LABEL: @vector_scalable_insert
+// CHECK-SAME: %[[SUB:.*]]: vector<4xf32>, %[[SV:.*]]: vector<[4]xf32>
+func.func @vector_scalable_insert(%sub: vector<4xf32>, %dsv: vector<[4]xf32>) -> vector<[4]xf32> {
+  // CHECK-NEXT: %[[TMP:.*]] = llvm.intr.vector.insert %[[SUB]], %[[SV]][0] : vector<4xf32> into vector<[4]xf32>
+  %0 = vector.scalable.insert %sub, %dsv[0] : vector<4xf32> into vector<[4]xf32>
+  // CHECK-NEXT: llvm.intr.vector.insert %[[SUB]], %[[TMP]][4] : vector<4xf32> into vector<[4]xf32>
+  %1 = vector.scalable.insert %sub, %0[4] : vector<4xf32> into vector<[4]xf32>
+  return %1 : vector<[4]xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @vector_scalable_extract
+// CHECK-SAME: %[[VEC:.*]]: vector<[4]xf32>
+func.func @vector_scalable_extract(%vec: vector<[4]xf32>) -> vector<8xf32> {
+  // CHECK-NEXT: %{{.*}} = llvm.intr.vector.extract %[[VEC]][0] : vector<8xf32> from vector<[4]xf32>
+  %0 = vector.scalable.extract %vec[0] : vector<8xf32> from vector<[4]xf32>
+  return %0 : 
vector<8xf32>
+}
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -1632,3 +1632,16 @@
   return
 }
 
+// -----
+
+func.func @vector_scalable_insert_unaligned(%subv: vector<4xi32>, %vec: vector<[16]xi32>) {
+  // expected-error@+1 {{op failed to verify that position is a multiple of the source length.}}
+  %0 = vector.scalable.insert %subv, %vec[2] : vector<4xi32> into vector<[16]xi32>
+}
+
+// -----
+
+func.func @vector_scalable_extract_unaligned(%vec: vector<[16]xf32>) {
+  // expected-error@+1 {{op failed to verify that position is a multiple of the result length.}}
+  %0 = vector.scalable.extract %vec[5] : vector<4xf32> from vector<[16]xf32>
+}
diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir
--- a/mlir/test/Dialect/Vector/ops.mlir
+++ b/mlir/test/Dialect/Vector/ops.mlir
@@ -853,3 +853,28 @@
   return
 }
 
+// CHECK-LABEL: func @vector_scalable_insert(
+// CHECK-SAME: %[[SUB0:.*]]: vector<4xi32>, %[[SUB1:.*]]: vector<8xi32>,
+// CHECK-SAME: %[[SUB2:.*]]: vector<[4]xi32>, %[[SV:.*]]: vector<[8]xi32>
+func.func @vector_scalable_insert(%sub0: vector<4xi32>, %sub1: vector<8xi32>,
+                                  %sub2: vector<[4]xi32>, %sv: vector<[8]xi32>) {
+  // CHECK-NEXT: vector.scalable.insert %[[SUB0]], %[[SV]][12] : vector<4xi32> into vector<[8]xi32>
+  %0 = vector.scalable.insert %sub0, %sv[12] : vector<4xi32> into vector<[8]xi32>
+  // CHECK-NEXT: vector.scalable.insert %[[SUB1]], %[[SV]][0] : vector<8xi32> into vector<[8]xi32>
+  %1 = vector.scalable.insert %sub1, %sv[0] : vector<8xi32> into vector<[8]xi32>
+  // CHECK-NEXT: vector.scalable.insert %[[SUB2]], %[[SV]][0] : vector<[4]xi32> into vector<[8]xi32>
+  %2 = vector.scalable.insert %sub2, %sv[0] : vector<[4]xi32> into vector<[8]xi32>
+  return
+}
+
+// CHECK-LABEL: func @vector_scalable_extract(
+// CHECK-SAME: %[[SV:.*]]: vector<[8]xi32>
+func.func @vector_scalable_extract(%sv: 
vector<[8]xi32>) {
+  // CHECK-NEXT: vector.scalable.extract %[[SV]][0] : vector<16xi32> from vector<[8]xi32>
+  %0 = vector.scalable.extract %sv[0] : vector<16xi32> from vector<[8]xi32>
+  // CHECK-NEXT: vector.scalable.extract %[[SV]][0] : vector<[4]xi32> from vector<[8]xi32>
+  %1 = vector.scalable.extract %sv[0] : vector<[4]xi32> from vector<[8]xi32>
+  // CHECK-NEXT: vector.scalable.extract %[[SV]][4] : vector<4xi32> from vector<[8]xi32>
+  %2 = vector.scalable.extract %sv[4] : vector<4xi32> from vector<[8]xi32>
+  return
+}