diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -19,10 +19,40 @@
 #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
 #include "mlir/Dialect/Vector/VectorOps.h"
 #include "mlir/Transforms/DialectConversion.h"
+#include <numeric>
 
 using namespace mlir;
 
 namespace {
+/// Returns the zero-extended value of the first element of an integer array.
+static uint64_t getFirstIntValue(ArrayAttr attr) {
+  return attr.getAsValueRange<IntegerAttr>().begin()->getZExtValue();
+}
+
+/// Converts vector.bitcast to spv.Bitcast, or forwards the source value when
+/// the converted source and result types are already identical.
+struct VectorBitcastConvert final
+    : public OpConversionPattern<vector::BitCastOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(vector::BitCastOp bitcastOp, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto dstType = getTypeConverter()->convertType(bitcastOp.getType());
+    if (!dstType)
+      return failure();
+
+    vector::BitCastOp::Adaptor adaptor(operands);
+    if (dstType == adaptor.source().getType())
+      rewriter.replaceOp(bitcastOp, adaptor.source());
+    else
+      rewriter.replaceOpWithNewOp<spirv::BitcastOp>(bitcastOp, dstType,
+                                                    adaptor.source());
+
+    return success();
+  }
+};
+
 struct VectorBroadcastConvert final
     : public OpConversionPattern<vector::BroadcastOp> {
   using OpConversionPattern::OpConversionPattern;
@@ -49,17 +79,72 @@
   LogicalResult
   matchAndRewrite(vector::ExtractOp extractOp, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const override {
-    if (extractOp.getType().isa<VectorType>() ||
-        !spirv::CompositeType::isValid(extractOp.getVectorType()))
+    // Only support extracting a scalar value now.
+    VectorType resultVectorType = extractOp.getType().dyn_cast<VectorType>();
+    if (resultVectorType && resultVectorType.getNumElements() > 1)
       return failure();
+
+    auto dstType = getTypeConverter()->convertType(extractOp.getType());
+    if (!dstType)
+      return failure();
+
     vector::ExtractOp::Adaptor adaptor(operands);
+    // A one-element source vector is converted to its scalar element type.
+    // That scalar already is the extracted value, and spv.CompositeExtract
+    // requires a composite operand, so forward it directly.
+    if (adaptor.vector().getType().isa<spirv::ScalarType>()) {
+      rewriter.replaceOp(extractOp, adaptor.vector());
+      return success();
+    }
+    // Any other source must be a vector type SPIR-V can represent.
+    if (!spirv::CompositeType::isValid(extractOp.getVectorType()))
+      return failure();
+
-    int32_t id = extractOp.position().begin()->cast<IntegerAttr>().getInt();
+    int32_t id = getFirstIntValue(extractOp.position());
     rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>(
         extractOp, adaptor.vector(), id);
     return success();
   }
 };
 
+/// Converts a unit-stride vector.extract_strided_slice into spv.VectorShuffle
+/// that selects the contiguous slice from the (converted) source vector.
+struct VectorExtractStridedSliceOpConvert final
+    : public OpConversionPattern<vector::ExtractStridedSliceOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(vector::ExtractStridedSliceOp extractOp,
+                  ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto dstType = getTypeConverter()->convertType(extractOp.getType());
+    if (!dstType)
+      return failure();
+
+    // Extract vector<1xT> not supported yet.
+    if (dstType.isa<spirv::ScalarType>())
+      return failure();
+
+    uint64_t offset = getFirstIntValue(extractOp.offsets());
+    uint64_t size = getFirstIntValue(extractOp.sizes());
+    uint64_t stride = getFirstIntValue(extractOp.strides());
+    if (stride != 1)
+      return failure();
+
+    Value srcVector = operands.front();
+
+    // Pick lanes [offset, offset + size); both shuffle inputs are the source.
+    SmallVector<int32_t, 4> indices(size);
+    std::iota(indices.begin(), indices.end(), offset);
+
+    rewriter.replaceOpWithNewOp<spirv::VectorShuffleOp>(
+        extractOp, dstType, srcVector, srcVector,
+        rewriter.getI32ArrayAttr(indices));
+
+    return success();
+  }
+};
+
 struct VectorFmaOpConvert final : public OpConversionPattern<vector::FMAOp> {
   using OpConversionPattern::OpConversionPattern;
 
@@ -86,7 +171,7 @@
         !spirv::CompositeType::isValid(insertOp.getDestVectorType()))
       return failure();
     vector::InsertOp::Adaptor adaptor(operands);
-    int32_t id = insertOp.position().begin()->cast<IntegerAttr>().getInt();
+    int32_t id = getFirstIntValue(insertOp.position());
     rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>(
         insertOp, adaptor.source(), adaptor.dest(), id);
     return success();
@@ -129,13 +214,58 @@
   }
 };
 
+/// Converts a unit-stride vector.insert_strided_slice into spv.VectorShuffle
+/// that keeps destination lanes outside the slice and source lanes inside.
+struct VectorInsertStridedSliceOpConvert final
+    : public OpConversionPattern<vector::InsertStridedSliceOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(vector::InsertStridedSliceOp insertOp,
+                  ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    Value srcVector = operands.front();
+    Value dstVector = operands.back();
+
+    // Insert scalar values not supported yet.
+    if (srcVector.getType().isa<spirv::ScalarType>() ||
+        dstVector.getType().isa<spirv::ScalarType>())
+      return failure();
+
+    uint64_t stride = getFirstIntValue(insertOp.strides());
+    if (stride != 1)
+      return failure();
+
+    uint64_t totalSize =
+        dstVector.getType().cast<VectorType>().getNumElements();
+    uint64_t insertSize =
+        srcVector.getType().cast<VectorType>().getNumElements();
+    uint64_t offset = getFirstIntValue(insertOp.offsets());
+
+    // Start from the identity selection of the destination lanes, then
+    // redirect the slice to the source lanes, which are numbered from
+    // totalSize in the shuffle's concatenated operand space.
+    SmallVector<int32_t, 4> indices(totalSize);
+    std::iota(indices.begin(), indices.end(), 0);
+    std::iota(indices.begin() + offset, indices.begin() + offset + insertSize,
+              totalSize);
+
+    rewriter.replaceOpWithNewOp<spirv::VectorShuffleOp>(
+        insertOp, dstVector.getType(), dstVector, srcVector,
+        rewriter.getI32ArrayAttr(indices));
+
+    return success();
+  }
+};
+
 } // namespace
 
 void mlir::populateVectorToSPIRVPatterns(MLIRContext *context,
                                          SPIRVTypeConverter &typeConverter,
                                          OwningRewritePatternList &patterns) {
-  patterns.insert<VectorBroadcastConvert, VectorExtractElementOpConvert,
-                  VectorExtractOpConvert, VectorFmaOpConvert,
-                  VectorInsertOpConvert, VectorInsertElementOpConvert>(
-      typeConverter, context);
+  patterns.insert<VectorBitcastConvert, VectorBroadcastConvert,
+                  VectorExtractElementOpConvert, VectorExtractOpConvert,
+                  VectorExtractStridedSliceOpConvert, VectorFmaOpConvert,
+                  VectorInsertElementOpConvert, VectorInsertOpConvert,
+                  VectorInsertStridedSliceOpConvert>(typeConverter, context);
 }
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
--- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
@@ -269,12 +269,19 @@
 static Optional<Type>
 convertVectorType(const spirv::TargetEnv &targetEnv, VectorType type,
                   Optional<spirv::StorageClass> storageClass = {}) {
+  if (type.getRank() == 1 && type.getNumElements() == 1) {
+    // Represent a one-element vector by its element type, converted like
+    // any other scalar so that target capabilities are still honored.
+    auto scalarType = type.getElementType().dyn_cast<spirv::ScalarType>();
+    if (!scalarType)
+      return llvm::None;
+    return convertScalarType(targetEnv, scalarType, storageClass);
+  }
+
   if (!spirv::CompositeType::isValid(type)) {
-    // TODO: One-element vector types can be translated into scalar
-    // types. Vector types with more than four elements can be translated into
+    // TODO: Vector types with more than four elements can be translated into
     // array types.
-    LLVM_DEBUG(llvm::dbgs()
-               << type << " illegal: 1- and > 4-element unimplemented\n");
+    LLVM_DEBUG(llvm::dbgs() << type << " illegal: > 4-element unimplemented\n");
     return llvm::None;
   }
 
diff --git a/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir b/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir
--- a/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir
+++ b/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir
@@ -117,9 +117,9 @@
   return
 }
 
-// CHECK-LABEL: @unsupported_1elem_vector
-func @unsupported_1elem_vector(%arg0: vector<1xi32>) {
-  // CHECK: addi
+// CHECK-LABEL: @one_elem_vector
+func @one_elem_vector(%arg0: vector<1xi32>) {
+  // CHECK: spv.IAdd %{{.+}}, %{{.+}}: i32
   %0 = addi %arg0, %arg0: vector<1xi32>
   return
 }
diff --git a/mlir/test/Conversion/StandardToSPIRV/std-types-to-spirv.mlir b/mlir/test/Conversion/StandardToSPIRV/std-types-to-spirv.mlir
--- a/mlir/test/Conversion/StandardToSPIRV/std-types-to-spirv.mlir
+++ b/mlir/test/Conversion/StandardToSPIRV/std-types-to-spirv.mlir
@@ -203,18 +203,19 @@
   %arg1: vector<3xf64>
 ) { return }
 
+// CHECK-LABEL: spv.func @one_element_vector
+// CHECK-SAME: %{{.+}}: i32
+func @one_element_vector(%arg0: vector<1xi32>) { return }
+
 } // end module
 
 // -----
 
-// Check that 1- or > 4-element vectors are not supported.
+// Check that > 4-element vectors are not supported.
 module attributes {
   spv.target_env = #spv.target_env<#spv.vce<v1.0, [], []>, {}>
 } {
 
-// CHECK-NOT: spv.func @one_element_vector
-func @one_element_vector(%arg0: vector<1xi32>) { return }
-
 // CHECK-NOT: spv.func @large_vector
 func @large_vector(%arg0: vector<1024xi32>) { return }
 
diff --git a/mlir/test/Conversion/VectorToSPIRV/simple.mlir b/mlir/test/Conversion/VectorToSPIRV/simple.mlir
--- a/mlir/test/Conversion/VectorToSPIRV/simple.mlir
+++ b/mlir/test/Conversion/VectorToSPIRV/simple.mlir
@@ -1,5 +1,33 @@
 // RUN: mlir-opt -split-input-file -convert-vector-to-spirv -verify-diagnostics %s -o - | FileCheck %s
 
+module attributes { spv.target_env = #spv.target_env<#spv.vce<v1.0, [Float16], []>, {}> } {
+
+// CHECK-LABEL: func @bitcast
+// CHECK-SAME: %[[ARG0:.+]]: vector<2xf32>, %[[ARG1:.+]]: vector<2xf16>
+// CHECK: %{{.+}} = spv.Bitcast %[[ARG0]] : vector<2xf32> to vector<4xf16>
+// CHECK: %{{.+}} = spv.Bitcast %[[ARG1]] : vector<2xf16> to f32
+func @bitcast(%arg0 : vector<2xf32>, %arg1: vector<2xf16>) {
+  %0 = vector.bitcast %arg0 : vector<2xf32> to vector<4xf16>
+  %1 = vector.bitcast %arg1 : vector<2xf16> to vector<1xf32>
+  spv.Return
+}
+
+// CHECK-LABEL: func @extract_size1_vector
+// CHECK-SAME: %[[ARG0:.+]]: vector<2xf16>
+// CHECK: %[[R:.+]] = spv.Bitcast %[[ARG0]] : vector<2xf16> to f32
+// CHECK-NOT: spv.CompositeExtract
+// CHECK: spv.CompositeConstruct %[[R]], %[[R]], %[[R]], %[[R]] : vector<4xf32>
+func @extract_size1_vector(%arg0 : vector<2xf16>) {
+  %0 = vector.bitcast %arg0 : vector<2xf16> to vector<1xf32>
+  %1 = vector.extract %0[0] : vector<1xf32>
+  %2 = vector.broadcast %1 : f32 to vector<4xf32>
+  spv.Return
+}
+
+} // end module
+
+// -----
+
 // CHECK-LABEL: broadcast
 // CHECK-SAME: %[[A:.*]]: f32
 // CHECK: spv.CompositeConstruct %[[A]], %[[A]], %[[A]], %[[A]] : vector<4xf32>
@@ -12,6 +40,18 @@
 
 // -----
 
+// CHECK-LABEL: func @extract
+// CHECK-SAME: %[[ARG:.+]]: vector<2xf32>
+// CHECK: %{{.+}} = spv.CompositeExtract %[[ARG]][0 : i32] : vector<2xf32>
+// CHECK: %{{.+}} = spv.CompositeExtract %[[ARG]][1 : i32] : vector<2xf32>
+func @extract(%arg0 : vector<2xf32>) {
+  %0 = "vector.extract"(%arg0) {position = [0]} : (vector<2xf32>) -> vector<1xf32>
+  %1 = "vector.extract"(%arg0) {position = [1]} : (vector<2xf32>) -> f32
+  spv.Return
+}
+
+// -----
+
 // CHECK-LABEL: extract_insert
 // CHECK-SAME: %[[V:.*]]: vector<4xf32>
 // CHECK: %[[S:.*]] = spv.CompositeExtract %[[V]][1 : i32] : vector<4xf32>
@@ -42,6 +82,16 @@
 
 // -----
 
+// CHECK-LABEL: func @extract_strided_slice
+// CHECK-SAME: %[[ARG:.+]]: vector<4xf32>
+// CHECK: %{{.+}} = spv.VectorShuffle [1 : i32, 2 : i32] %[[ARG]] : vector<4xf32>, %[[ARG]] : vector<4xf32> -> vector<2xf32>
+func @extract_strided_slice(%arg0: vector<4xf32>) {
+  %0 = vector.extract_strided_slice %arg0 {offsets = [1], sizes = [2], strides = [1]} : vector<4xf32> to vector<2xf32>
+  spv.Return
+}
+
+// -----
+
 // CHECK-LABEL: insert_element
 // CHECK-SAME: %[[VAL:.*]]: f32, %[[V:.*]]: vector<4xf32>, %[[ID:.*]]: i32
 // CHECK: spv.VectorInsertDynamic %[[VAL]], %[[V]][%[[ID]]] : vector<4xf32>, i32
@@ -60,6 +110,16 @@
 
 // -----
 
+// CHECK-LABEL: func @insert_strided_slice
+// CHECK-SAME: %[[PART:.+]]: vector<2xf32>, %[[ALL:.+]]: vector<4xf32>
+// CHECK: %{{.+}} = spv.VectorShuffle [0 : i32, 4 : i32, 5 : i32, 3 : i32] %[[ALL]] : vector<4xf32>, %[[PART]] : vector<2xf32> -> vector<4xf32>
+func @insert_strided_slice(%arg0: vector<2xf32>, %arg1: vector<4xf32>) {
+  %0 = vector.insert_strided_slice %arg0, %arg1 {offsets = [1], strides = [1]} : vector<2xf32> into vector<4xf32>
+  spv.Return
+}
+
+// -----
+
 // CHECK-LABEL: func @fma
 // CHECK-SAME: %[[A:.*]]: vector<4xf32>, %[[B:.*]]: vector<4xf32>, %[[C:.*]]: vector<4xf32>
 // CHECK: spv.GLSL.Fma %[[A]], %[[B]], %[[C]] : vector<4xf32>