diff --git a/mlir/include/mlir/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.h b/mlir/include/mlir/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.h
--- a/mlir/include/mlir/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.h
+++ b/mlir/include/mlir/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.h
@@ -20,6 +20,20 @@
 class MLIRContext;
 class ModuleOp;
 
+/// Base class for SPIR-V to LLVM conversion patterns. It stores a reference
+/// to the LLVM type converter for use by derived patterns.
+template <typename SPIRVOp>
+class SPIRVToLLVMConversion : public OpConversionPattern<SPIRVOp> {
+public:
+  SPIRVToLLVMConversion(MLIRContext *context, LLVMTypeConverter &typeConverter,
+                        PatternBenefit benefit = 1)
+      : OpConversionPattern<SPIRVOp>(context, benefit),
+        typeConverter(typeConverter) {}
+
+protected:
+  LLVMTypeConverter &typeConverter;
+};
+
 /// Populates the given list with patterns that convert from SPIR-V to LLVM.
 void populateSPIRVToLLVMConversionPatterns(MLIRContext *context,
                                            LLVMTypeConverter &typeConverter,
diff --git a/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp b/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp
--- a/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp
+++ b/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp
@@ -24,30 +24,51 @@
 
 using namespace mlir;
 
+//===----------------------------------------------------------------------===//
+// Operation conversion
+//===----------------------------------------------------------------------===//
+
 namespace {
 
-class BitwiseAndOpConversion : public ConvertToLLVMPattern {
+/// Converts SPIR-V operations that have straightforward LLVM equivalent
+/// into LLVM dialect operations.
+template <typename SPIRVOp, typename LLVMOp>
+class DirectConversionPattern : public SPIRVToLLVMConversion<SPIRVOp> {
 public:
-  explicit BitwiseAndOpConversion(MLIRContext *context,
-                                  LLVMTypeConverter &typeConverter)
-      : ConvertToLLVMPattern(spirv::BitwiseAndOp::getOperationName(), context,
-                             typeConverter) {}
+  using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion;
 
   LogicalResult
-  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+  matchAndRewrite(SPIRVOp operation, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const override {
-    auto bitwiseAndOp = cast<spirv::BitwiseAndOp>(op);
-    auto dstType = typeConverter.convertType(bitwiseAndOp.getType());
+    auto dstType = this->typeConverter.convertType(operation.getType());
     if (!dstType)
       return failure();
-    rewriter.replaceOpWithNewOp<LLVM::AndOp>(bitwiseAndOp, dstType, operands);
+    rewriter.template replaceOpWithNewOp<LLVMOp>(operation, dstType, operands);
     return success();
   }
 };
 } // namespace
 
+//===----------------------------------------------------------------------===//
+// Pattern population
+//===----------------------------------------------------------------------===//
+
 void mlir::populateSPIRVToLLVMConversionPatterns(
     MLIRContext *context, LLVMTypeConverter &typeConverter,
     OwningRewritePatternList &patterns) {
-  patterns.insert<BitwiseAndOpConversion>(context, typeConverter);
+  patterns.insert<DirectConversionPattern<spirv::IAddOp, LLVM::AddOp>,
+                  DirectConversionPattern<spirv::IMulOp, LLVM::MulOp>,
+                  DirectConversionPattern<spirv::ISubOp, LLVM::SubOp>,
+                  DirectConversionPattern<spirv::FAddOp, LLVM::FAddOp>,
+                  DirectConversionPattern<spirv::FDivOp, LLVM::FDivOp>,
+                  DirectConversionPattern<spirv::FNegateOp, LLVM::FNegOp>,
+                  DirectConversionPattern<spirv::FRemOp, LLVM::FRemOp>,
+                  DirectConversionPattern<spirv::FSubOp, LLVM::FSubOp>,
+                  DirectConversionPattern<spirv::SDivOp, LLVM::SDivOp>,
+                  DirectConversionPattern<spirv::SRemOp, LLVM::SRemOp>,
+                  DirectConversionPattern<spirv::UDivOp, LLVM::UDivOp>,
+                  DirectConversionPattern<spirv::BitwiseAndOp, LLVM::AndOp>,
+                  DirectConversionPattern<spirv::BitwiseOrOp, LLVM::OrOp>,
+                  DirectConversionPattern<spirv::BitwiseXorOp, LLVM::XOrOp>>(
+      context, typeConverter);
 }
diff --git a/mlir/test/Conversion/SPIRVToLLVM/arithmetic-ops-to-llvm.mlir b/mlir/test/Conversion/SPIRVToLLVM/arithmetic-ops-to-llvm.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Conversion/SPIRVToLLVM/arithmetic-ops-to-llvm.mlir
@@ -0,0 +1,177 @@
+// RUN: mlir-opt -convert-spirv-to-llvm %s | FileCheck %s
+
+//===----------------------------------------------------------------------===//
+// spv.IAdd
+//===----------------------------------------------------------------------===//
+
+func @iadd_scalar(%arg0: i32, %arg1: i32) {
+  // CHECK: %{{.*}} = llvm.add %{{.*}}, %{{.*}} : !llvm.i32
+  %0 = spv.IAdd %arg0, %arg1 : i32
+  return
+}
+
+func @iadd_vector(%arg0: vector<4xi64>, %arg1: vector<4xi64>) {
+  // CHECK: %{{.*}} = llvm.add %{{.*}}, %{{.*}} : !llvm<"<4 x i64>">
+  %0 = spv.IAdd %arg0, %arg1 : vector<4xi64>
+  return
+}
+
+//===----------------------------------------------------------------------===//
+// spv.ISub
+//===----------------------------------------------------------------------===//
+
+func @isub_scalar(%arg0: i8, %arg1: i8) {
+  // CHECK: %{{.*}} = llvm.sub %{{.*}}, %{{.*}} : !llvm.i8
+  %0 = spv.ISub %arg0, %arg1 : i8
+  return
+}
+
+func @isub_vector(%arg0: vector<2xi16>, %arg1: vector<2xi16>) {
+  // CHECK: %{{.*}} = llvm.sub %{{.*}}, %{{.*}} : !llvm<"<2 x i16>">
+  %0 = spv.ISub %arg0, %arg1 : vector<2xi16>
+  return
+}
+
+//===----------------------------------------------------------------------===//
+// spv.IMul
+//===----------------------------------------------------------------------===//
+
+func @imul_scalar(%arg0: i32, %arg1: i32) {
+  // CHECK: %{{.*}} = llvm.mul %{{.*}}, %{{.*}} : !llvm.i32
+  %0 = spv.IMul %arg0, %arg1 : i32
+  return
+}
+
+func @imul_vector(%arg0: vector<3xi32>, %arg1: vector<3xi32>) {
+  // CHECK: %{{.*}} = llvm.mul %{{.*}}, %{{.*}} : !llvm<"<3 x i32>">
+  %0 = spv.IMul %arg0, %arg1 : vector<3xi32>
+  return
+}
+
+//===----------------------------------------------------------------------===//
+// spv.FAdd
+//===----------------------------------------------------------------------===//
+
+func @fadd_scalar(%arg0: f16, %arg1: f16) {
+  // CHECK: %{{.*}} = llvm.fadd %{{.*}}, %{{.*}} : !llvm.half
+  %0 = spv.FAdd %arg0, %arg1 : f16
+  return
+}
+
+func @fadd_vector(%arg0: vector<4xf32>, %arg1: vector<4xf32>) {
+  // CHECK: %{{.*}} = llvm.fadd %{{.*}}, %{{.*}} : !llvm<"<4 x float>">
+  %0 = spv.FAdd %arg0, %arg1 : vector<4xf32>
+  return
+}
+
+//===----------------------------------------------------------------------===//
+// spv.FSub
+//===----------------------------------------------------------------------===//
+
+func @fsub_scalar(%arg0: f32, %arg1: f32) {
+  // CHECK: %{{.*}} = llvm.fsub %{{.*}}, %{{.*}} : !llvm.float
+  %0 = spv.FSub %arg0, %arg1 : f32
+  return
+}
+
+func @fsub_vector(%arg0: vector<2xf32>, %arg1: vector<2xf32>) {
+  // CHECK: %{{.*}} = llvm.fsub %{{.*}}, %{{.*}} : !llvm<"<2 x float>">
+  %0 = spv.FSub %arg0, %arg1 : vector<2xf32>
+  return
+}
+
+//===----------------------------------------------------------------------===//
+// spv.FDiv
+//===----------------------------------------------------------------------===//
+
+func @fdiv_scalar(%arg0: f32, %arg1: f32) {
+  // CHECK: %{{.*}} = llvm.fdiv %{{.*}}, %{{.*}} : !llvm.float
+  %0 = spv.FDiv %arg0, %arg1 : f32
+  return
+}
+
+func @fdiv_vector(%arg0: vector<3xf64>, %arg1: vector<3xf64>) {
+  // CHECK: %{{.*}} = llvm.fdiv %{{.*}}, %{{.*}} : !llvm<"<3 x double>">
+  %0 = spv.FDiv %arg0, %arg1 : vector<3xf64>
+  return
+}
+
+//===----------------------------------------------------------------------===//
+// spv.FRem
+//===----------------------------------------------------------------------===//
+
+func @frem_scalar(%arg0: f32, %arg1: f32) {
+  // CHECK: %{{.*}} = llvm.frem %{{.*}}, %{{.*}} : !llvm.float
+  %0 = spv.FRem %arg0, %arg1 : f32
+  return
+}
+
+func @frem_vector(%arg0: vector<3xf64>, %arg1: vector<3xf64>) {
+  // CHECK: %{{.*}} = llvm.frem %{{.*}}, %{{.*}} : !llvm<"<3 x double>">
+  %0 = spv.FRem %arg0, %arg1 : vector<3xf64>
+  return
+}
+
+//===----------------------------------------------------------------------===//
+// spv.FNegate
+//===----------------------------------------------------------------------===//
+
+func @fneg_scalar(%arg: f64) {
+  // CHECK: %{{.*}} = llvm.fneg %{{.*}} : !llvm.double
+  %0 = spv.FNegate %arg : f64
+  return
+}
+
+func @fneg_vector(%arg: vector<2xf32>) {
+  // CHECK: %{{.*}} = llvm.fneg %{{.*}} : !llvm<"<2 x float>">
+  %0 = spv.FNegate %arg : vector<2xf32>
+  return
+}
+
+//===----------------------------------------------------------------------===//
+// spv.UDiv
+//===----------------------------------------------------------------------===//
+
+func @udiv_scalar(%arg0: i32, %arg1: i32) {
+  // CHECK: %{{.*}} = llvm.udiv %{{.*}}, %{{.*}} : !llvm.i32
+  %0 = spv.UDiv %arg0, %arg1 : i32
+  return
+}
+
+func @udiv_vector(%arg0: vector<3xi64>, %arg1: vector<3xi64>) {
+  // CHECK: %{{.*}} = llvm.udiv %{{.*}}, %{{.*}} : !llvm<"<3 x i64>">
+  %0 = spv.UDiv %arg0, %arg1 : vector<3xi64>
+  return
+}
+
+//===----------------------------------------------------------------------===//
+// spv.SDiv
+//===----------------------------------------------------------------------===//
+
+func @sdiv_scalar(%arg0: i16, %arg1: i16) {
+  // CHECK: %{{.*}} = llvm.sdiv %{{.*}}, %{{.*}} : !llvm.i16
+  %0 = spv.SDiv %arg0, %arg1 : i16
+  return
+}
+
+func @sdiv_vector(%arg0: vector<2xi64>, %arg1: vector<2xi64>) {
+  // CHECK: %{{.*}} = llvm.sdiv %{{.*}}, %{{.*}} : !llvm<"<2 x i64>">
+  %0 = spv.SDiv %arg0, %arg1 : vector<2xi64>
+  return
+}
+
+//===----------------------------------------------------------------------===//
+// spv.SRem
+//===----------------------------------------------------------------------===//
+
+func @srem_scalar(%arg0: i32, %arg1: i32) {
+  // CHECK: %{{.*}} = llvm.srem %{{.*}}, %{{.*}} : !llvm.i32
+  %0 = spv.SRem %arg0, %arg1 : i32
+  return
+}
+
+func @srem_vector(%arg0: vector<4xi32>, %arg1: vector<4xi32>) {
+  // CHECK: %{{.*}} = llvm.srem %{{.*}}, %{{.*}} : !llvm<"<4 x i32>">
+  %0 = spv.SRem %arg0, %arg1 : vector<4xi32>
+  return
+}
diff --git a/mlir/test/Conversion/SPIRVToLLVM/bitwise-ops-to-llvm.mlir b/mlir/test/Conversion/SPIRVToLLVM/bitwise-ops-to-llvm.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Conversion/SPIRVToLLVM/bitwise-ops-to-llvm.mlir
@@ -0,0 +1,49 @@
+// RUN: mlir-opt -convert-spirv-to-llvm %s | FileCheck %s
+
+//===----------------------------------------------------------------------===//
+// spv.BitwiseAnd
+//===----------------------------------------------------------------------===//
+
+func @bitwise_and_scalar(%arg0: i32, %arg1: i32) {
+  // CHECK: %{{.*}} = llvm.and %{{.*}}, %{{.*}} : !llvm.i32
+  %0 = spv.BitwiseAnd %arg0, %arg1 : i32
+  return
+}
+
+func @bitwise_and_vector(%arg0: vector<4xi64>, %arg1: vector<4xi64>) {
+  // CHECK: %{{.*}} = llvm.and %{{.*}}, %{{.*}} : !llvm<"<4 x i64>">
+  %0 = spv.BitwiseAnd %arg0, %arg1 : vector<4xi64>
+  return
+}
+
+//===----------------------------------------------------------------------===//
+// spv.BitwiseOr
+//===----------------------------------------------------------------------===//
+
+func @bitwise_or_scalar(%arg0: i64, %arg1: i64) {
+  // CHECK: %{{.*}} = llvm.or %{{.*}}, %{{.*}} : !llvm.i64
+  %0 = spv.BitwiseOr %arg0, %arg1 : i64
+  return
+}
+
+func @bitwise_or_vector(%arg0: vector<3xi8>, %arg1: vector<3xi8>) {
+  // CHECK: %{{.*}} = llvm.or %{{.*}}, %{{.*}} : !llvm<"<3 x i8>">
+  %0 = spv.BitwiseOr %arg0, %arg1 : vector<3xi8>
+  return
+}
+
+//===----------------------------------------------------------------------===//
+// spv.BitwiseXor
+//===----------------------------------------------------------------------===//
+
+func @bitwise_xor_scalar(%arg0: i32, %arg1: i32) {
+  // CHECK: %{{.*}} = llvm.xor %{{.*}}, %{{.*}} : !llvm.i32
+  %0 = spv.BitwiseXor %arg0, %arg1 : i32
+  return
+}
+
+func @bitwise_xor_vector(%arg0: vector<2xi16>, %arg1: vector<2xi16>) {
+  // CHECK: %{{.*}} = llvm.xor %{{.*}}, %{{.*}} : !llvm<"<2 x i16>">
+  %0 = spv.BitwiseXor %arg0, %arg1 : vector<2xi16>
+  return
+}
diff --git a/mlir/test/Conversion/SPIRVToLLVM/convert-to-llvm.mlir b/mlir/test/Conversion/SPIRVToLLVM/convert-to-llvm.mlir
deleted file mode 100644
--- a/mlir/test/Conversion/SPIRVToLLVM/convert-to-llvm.mlir
+++ /dev/null
@@ -1,13 +0,0 @@
-// RUN: mlir-opt -convert-spirv-to-llvm %s | FileCheck %s
-
-func @bitwise_and_scalar(%arg0: i32, %arg1: i32) {
-  // CHECK: %{{.*}} = llvm.and %{{.*}}, %{{.*}} : !llvm.i32
-  %0 = spv.BitwiseAnd %arg0, %arg1 : i32
-  return
-}
-
-func @bitwise_and_vector(%arg0: vector<4xi64>, %arg1: vector<4xi64>) {
-  // CHECK: %{{.*}} = llvm.and %{{.*}}, %{{.*}} : !llvm<"<4 x i64>">
-  %0 = spv.BitwiseAnd %arg0, %arg1 : vector<4xi64>
-  return
-}