diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -19,7 +19,9 @@
 #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/BuiltinTypes.h"
 #include "mlir/Transforms/DialectConversion.h"
+#include "llvm/ADT/ArrayRef.h"
 #include "llvm/ADT/STLExtras.h"
 #include <numeric>
 
@@ -249,6 +251,78 @@
   }
 };
 
+/// Lowers a 1-D vector.reduction into scalar SPIR-V ops: every source element
+/// is extracted with spv.CompositeExtract, the optional accumulator is appended
+/// last, and the values are folded left-to-right with the op matching the
+/// combining kind. Only the ADD and MUL combine steps are implemented so far.
+struct VectorReductionPattern final
+    : public OpConversionPattern<vector::ReductionOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(vector::ReductionOp reduceOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Type resultType = typeConverter->convertType(reduceOp.getType());
+    if (!resultType)
+      return failure();
+
+    // NOTE(review): if the type converter turns single-element vectors into
+    // scalars, those sources are rejected by the check below; confirm/handle.
+    auto srcVectorType = adaptor.getVector().getType().dyn_cast<VectorType>();
+    if (!srcVectorType || srcVectorType.getRank() != 1)
+      return rewriter.notifyMatchFailure(reduceOp, "not 1-D vector source");
+
+    // Extract all elements. One extra slot is reserved for the optional
+    // accumulator, which is combined last.
+    int numElements = srcVectorType.getDimSize(0);
+    SmallVector<Value, 4> values;
+    values.reserve(numElements + (adaptor.getAcc() != nullptr));
+    Location loc = reduceOp.getLoc();
+    for (int i = 0; i < numElements; ++i) {
+      values.push_back(rewriter.create<spirv::CompositeExtractOp>(
+          loc, srcVectorType.getElementType(), adaptor.getVector(),
+          rewriter.getI32ArrayAttr({i})));
+    }
+    if (Value acc = adaptor.getAcc())
+      values.push_back(acc);
+
+    // Reduce them left-to-right; the accumulator, if present, is combined last.
+    Value result = values.front();
+    for (Value next : llvm::makeArrayRef(values).drop_front()) {
+      switch (reduceOp.getKind()) {
+#define INT_FLOAT_CASE(kind, iop, fop)                                         \
+  case vector::CombiningKind::kind:                                            \
+    if (resultType.isa<IntegerType>()) {                                       \
+      result = rewriter.create<spirv::iop>(loc, resultType, result, next);     \
+    } else {                                                                   \
+      assert(resultType.isa<FloatType>());                                     \
+      result = rewriter.create<spirv::fop>(loc, resultType, result, next);     \
+    }                                                                          \
+    break
+
+        INT_FLOAT_CASE(ADD, IAddOp, FAddOp);
+        INT_FLOAT_CASE(MUL, IMulOp, FMulOp);
+#undef INT_FLOAT_CASE
+
+      // Not lowered yet; reported as a match failure.
+      case vector::CombiningKind::MINUI:
+      case vector::CombiningKind::MINSI:
+      case vector::CombiningKind::MINF:
+      case vector::CombiningKind::MAXUI:
+      case vector::CombiningKind::MAXSI:
+      case vector::CombiningKind::MAXF:
+      case vector::CombiningKind::AND:
+      case vector::CombiningKind::OR:
+      case vector::CombiningKind::XOR:
+        return rewriter.notifyMatchFailure(reduceOp, "unimplemented");
+      }
+    }
+
+    rewriter.replaceOp(reduceOp, result);
+    return success();
+  }
+};
+
 class VectorSplatPattern final : public OpConversionPattern<vector::SplatOp> {
 public:
   using OpConversionPattern::OpConversionPattern;
@@ -312,6 +386,7 @@
                VectorExtractElementOpConvert, VectorExtractOpConvert,
                VectorExtractStridedSliceOpConvert, VectorFmaOpConvert,
                VectorInsertElementOpConvert, VectorInsertOpConvert,
-               VectorInsertStridedSliceOpConvert, VectorShuffleOpConvert,
-               VectorSplatPattern>(typeConverter, patterns.getContext());
+               VectorReductionPattern, VectorInsertStridedSliceOpConvert,
+               VectorShuffleOpConvert, VectorSplatPattern>(
+      typeConverter, patterns.getContext());
 }
diff --git a/mlir/test/Conversion/VectorToSPIRV/simple.mlir b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir rename from mlir/test/Conversion/VectorToSPIRV/simple.mlir rename to mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir --- a/mlir/test/Conversion/VectorToSPIRV/simple.mlir +++ b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir @@ -46,7 +46,7 @@ // CHECK: return %[[R]] func.func @extract_size1_vector(%arg0 : vector<1xf32>) -> 
f32 { %0 = vector.extract %arg0[0] : vector<1xf32> - return %0: f32 + return %0: f32 } // ----- @@ -56,7 +56,7 @@ // CHECK: spv.CompositeInsert %[[S]], %[[V]][2 : i32] : f32 into vector<4xf32> func.func @insert(%arg0 : vector<4xf32>, %arg1: f32) -> vector<4xf32> { %1 = vector.insert %arg1, %arg0[2] : f32 into vector<4xf32> - return %1: vector<4xf32> + return %1: vector<4xf32> } // ----- @@ -67,7 +67,7 @@ // CHECK: return %[[R]] func.func @insert_size1_vector(%arg0 : vector<1xf32>, %arg1: f32) -> vector<1xf32> { %1 = vector.insert %arg1, %arg0[0] : f32 into vector<1xf32> - return %1 : vector<1xf32> + return %1 : vector<1xf32> } // ----- @@ -84,7 +84,7 @@ // CHECK-LABEL: @extract_element_index func.func @extract_element_index(%arg0 : vector<4xf32>, %id : index) -> f32 { - // CHECK: vector.extractelement + // CHECK: vector.extractelement %0 = vector.extractelement %arg0[%id : index] : vector<4xf32> return %0: f32 } @@ -93,7 +93,7 @@ // CHECK-LABEL: @extract_element_size5_vector func.func @extract_element_size5_vector(%arg0 : vector<5xf32>, %id : i32) -> f32 { - // CHECK: vector.extractelement + // CHECK: vector.extractelement %0 = vector.extractelement %arg0[%id : i32] : vector<5xf32> return %0: f32 } @@ -124,7 +124,7 @@ // CHECK-LABEL: @insert_element_index func.func @insert_element_index(%val: f32, %arg0 : vector<4xf32>, %id : index) -> vector<4xf32> { - // CHECK: vector.insertelement + // CHECK: vector.insertelement %0 = vector.insertelement %val, %arg0[%id : index] : vector<4xf32> return %0: vector<4xf32> } @@ -133,19 +133,19 @@ // CHECK-LABEL: @insert_element_size5_vector func.func @insert_element_size5_vector(%val: f32, %arg0 : vector<5xf32>, %id : i32) -> vector<5xf32> { - // CHECK: vector.insertelement + // CHECK: vector.insertelement %0 = vector.insertelement %val, %arg0[%id : i32] : vector<5xf32> - return %0 : vector<5xf32> + return %0 : vector<5xf32> } // ----- // CHECK-LABEL: @insert_strided_slice // CHECK-SAME: %[[PART:.+]]: 
vector<4xf32> -// CHECK: spv.VectorShuffle [0 : i32, 4 : i32, 5 : i32, 3 : i32] %[[ALL]] : vector<4xf32>, %[[PART]] : vector<2xf32> -> vector<4xf32> +// CHECK: spv.VectorShuffle [0 : i32, 4 : i32, 5 : i32, 3 : i32] %[[ALL]] : vector<4xf32>, %[[PART]] : vector<2xf32> -> vector<4xf32> func.func @insert_strided_slice(%arg0: vector<2xf32>, %arg1: vector<4xf32>) -> vector<4xf32> { %0 = vector.insert_strided_slice %arg0, %arg1 {offsets = [1], strides = [1]} : vector<2xf32> into vector<4xf32> - return %0 : vector<4xf32> + return %0 : vector<4xf32> } // ----- @@ -156,7 +156,7 @@ // CHECK: spv.CompositeInsert %[[S]], %[[FULL]][2 : i32] : f32 into vector<3xf32> func.func @insert_size1_vector(%arg0 : vector<1xf32>, %arg1: vector<3xf32>) -> vector<3xf32> { %1 = vector.insert_strided_slice %arg0, %arg1 {offsets = [2], strides = [1]} : vector<1xf32> into vector<3xf32> - return %1 : vector<3xf32> + return %1 : vector<3xf32> } // ----- @@ -166,7 +166,7 @@ // CHECK: spv.GLSL.Fma %[[A]], %[[B]], %[[C]] : vector<4xf32> func.func @fma(%a: vector<4xf32>, %b: vector<4xf32>, %c: vector<4xf32>) -> vector<4xf32> { %0 = vector.fma %a, %b, %c: vector<4xf32> - return %0 : vector<4xf32> + return %0 : vector<4xf32> } // ----- @@ -210,3 +210,36 @@ %shuffle = vector.shuffle %v0, %v1 [0, 1, 2] : vector<2x16xf32>, vector<1x16xf32> return %shuffle : vector<3x16xf32> } + +// ----- + +// CHECK-LABEL: func @reduction +// CHECK-SAME: (%[[V:.+]]: vector<4xi32>) +// CHECK: %[[S0:.+]] = spv.CompositeExtract %[[V]][0 : i32] : vector<4xi32> +// CHECK: %[[S1:.+]] = spv.CompositeExtract %[[V]][1 : i32] : vector<4xi32> +// CHECK: %[[S2:.+]] = spv.CompositeExtract %[[V]][2 : i32] : vector<4xi32> +// CHECK: %[[S3:.+]] = spv.CompositeExtract %[[V]][3 : i32] : vector<4xi32> +// CHECK: %[[ADD0:.+]] = spv.IAdd %[[S0]], %[[S1]] +// CHECK: %[[ADD1:.+]] = spv.IAdd %[[ADD0]], %[[S2]] +// CHECK: %[[ADD2:.+]] = spv.IAdd %[[ADD1]], %[[S3]] +// CHECK: return %[[ADD2]] +func.func @reduction(%v : vector<4xi32>) -> i32 { + 
%reduce = vector.reduction <add>, %v : vector<4xi32> into i32 + return %reduce : i32 +} + +// ----- + +// CHECK-LABEL: func @reduction +// CHECK-SAME: (%[[V:.+]]: vector<3xf32>, %[[S:.+]]: f32) +// CHECK: %[[S0:.+]] = spv.CompositeExtract %[[V]][0 : i32] : vector<3xf32> +// CHECK: %[[S1:.+]] = spv.CompositeExtract %[[V]][1 : i32] : vector<3xf32> +// CHECK: %[[S2:.+]] = spv.CompositeExtract %[[V]][2 : i32] : vector<3xf32> +// CHECK: %[[MUL0:.+]] = spv.FMul %[[S0]], %[[S1]] +// CHECK: %[[MUL1:.+]] = spv.FMul %[[MUL0]], %[[S2]] +// CHECK: %[[MUL2:.+]] = spv.FMul %[[MUL1]], %[[S]] +// CHECK: return %[[MUL2]] +func.func @reduction(%v : vector<3xf32>, %s: f32) -> f32 { + %reduce = vector.reduction <mul>, %v, %s : vector<3xf32> into f32 + return %reduce : f32 +}