diff --git a/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp b/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp --- a/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp +++ b/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp @@ -6,28 +6,93 @@ // //===----------------------------------------------------------------------===// // -// This file implements patterns to convert Standard Ops to the SPIR-V dialect. +// This file implements patterns to convert standard ops to SPIR-V ops. // //===----------------------------------------------------------------------===// + #include "mlir/Dialect/SPIRV/LayoutUtils.h" #include "mlir/Dialect/SPIRV/SPIRVDialect.h" #include "mlir/Dialect/SPIRV/SPIRVLowering.h" #include "mlir/Dialect/SPIRV/SPIRVOps.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/IR/AffineMap.h" +#include "mlir/Support/LogicalResult.h" #include "llvm/ADT/SetVector.h" using namespace mlir; //===----------------------------------------------------------------------===// +// Utility functions +//===----------------------------------------------------------------------===// + +/// Returns true if the given `type` is a boolean scalar or vector type. +static bool isBoolScalarOrVector(Type type) { + if (type.isInteger(1)) + return true; + if (auto vecType = type.dyn_cast()) + return vecType.getElementType().isInteger(1); + return false; +} + +//===----------------------------------------------------------------------===// // Operation conversion //===----------------------------------------------------------------------===// +// Note that DRR cannot be used for the patterns in this file: we may need to +// convert type along the way, which requires ConversionPattern. DRR generates +// normal RewritePattern. + namespace { -/// Convert composite constant operation to SPIR-V dialect. -// TODO(denis0x0D) : move to DRR.
-class ConstantCompositeOpConversion final : public SPIRVOpLowering { +/// Converts binary standard operations to SPIR-V operations. +template +class BinaryOpPattern final : public SPIRVOpLowering { +public: + using SPIRVOpLowering::SPIRVOpLowering; + + LogicalResult + matchAndRewrite(StdOp operation, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + assert(operands.size() == 2); + auto dstType = this->typeConverter.convertType(operation.getType()); + if (!dstType) + return failure(); + rewriter.template replaceOpWithNewOp(operation, dstType, operands, + ArrayRef()); + return success(); + } +}; + +/// Converts bitwise standard operations to SPIR-V operations. This is a special +/// pattern separate from BinaryOpPattern because if the operands are +/// boolean values, SPIR-V uses different operations (`SPIRVLogicalOp`). For +/// non-boolean operands, SPIR-V should use `SPIRVBitwiseOp`. +template +class BitwiseOpPattern final : public SPIRVOpLowering { +public: + using SPIRVOpLowering::SPIRVOpLowering; + + LogicalResult + matchAndRewrite(StdOp operation, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + assert(operands.size() == 2); + auto dstType = + this->typeConverter.convertType(operation.getResult().getType()); + if (!dstType) + return failure(); + if (isBoolScalarOrVector(operands.front().getType())) { + rewriter.template replaceOpWithNewOp( + operation, dstType, operands, ArrayRef()); + } else { + rewriter.template replaceOpWithNewOp( + operation, dstType, operands, ArrayRef()); + } + return success(); + } +}; + +/// Converts composite std.constant operation to spv.constant. +class ConstantCompositeOpPattern final : public SPIRVOpLowering { public: using SPIRVOpLowering::SPIRVOpLowering; @@ -36,12 +101,8 @@ ConversionPatternRewriter &rewriter) const override; }; -/// Convert constant operation with IndexType return to SPIR-V constant -/// operation.
Since IndexType is not used within SPIR-V dialect, this needs -/// special handling to make sure the result type and the type of the value -/// attribute are consistent. -// TODO(ravishankarm) : This should be moved into DRR. -class ConstantIndexOpConversion final : public SPIRVOpLowering { +/// Converts scalar std.constant operation to spv.constant. +class ConstantScalarOpPattern final : public SPIRVOpLowering { public: using SPIRVOpLowering::SPIRVOpLowering; @@ -50,8 +111,8 @@ ConversionPatternRewriter &rewriter) const override; }; -/// Convert floating-point comparison operations to SPIR-V dialect. -class CmpFOpConversion final : public SPIRVOpLowering { +/// Converts floating-point comparison operations to SPIR-V ops. +class CmpFOpPattern final : public SPIRVOpLowering { public: using SPIRVOpLowering::SPIRVOpLowering; @@ -60,8 +121,8 @@ ConversionPatternRewriter &rewriter) const override; }; -/// Convert compare operation to SPIR-V dialect. -class CmpIOpConversion final : public SPIRVOpLowering { +/// Converts integer compare operation to SPIR-V ops. +class CmpIOpPattern final : public SPIRVOpLowering { public: using SPIRVOpLowering::SPIRVOpLowering; @@ -70,33 +131,8 @@ ConversionPatternRewriter &rewriter) const override; }; -/// Convert integer binary operations to SPIR-V operations. Cannot use -/// tablegen for this. If the integer operation is on variables of IndexType, -/// the type of the return value of the replacement operation differs from -/// that of the replaced operation. This is not handled in tablegen-based -/// pattern specification. -// TODO(ravishankarm) : This should be moved into DRR.
-template -class IntegerOpConversion final : public SPIRVOpLowering { -public: - using SPIRVOpLowering::SPIRVOpLowering; - - LogicalResult - matchAndRewrite(StdOp operation, ArrayRef operands, - ConversionPatternRewriter &rewriter) const override { - auto resultType = - this->typeConverter.convertType(operation.getResult().getType()); - rewriter.template replaceOpWithNewOp( - operation, resultType, operands, ArrayRef()); - return success(); - } -}; - -/// Convert load -> spv.LoadOp. The operands of the replaced operation are of -/// IndexType while that of the replacement operation are of type i32. This is -/// not supported in tablegen based pattern specification. -// TODO(ravishankarm) : This should be moved into DRR. -class LoadOpConversion final : public SPIRVOpLowering { +/// Converts std.load to spv.Load. +class LoadOpPattern final : public SPIRVOpLowering { public: using SPIRVOpLowering::SPIRVOpLowering; @@ -105,9 +141,8 @@ ConversionPatternRewriter &rewriter) const override; }; -/// Convert return -> spv.Return. -// TODO(ravishankarm) : This should be moved into DRR. -class ReturnOpConversion final : public SPIRVOpLowering { +/// Converts std.return to spv.Return. +class ReturnOpPattern final : public SPIRVOpLowering { public: using SPIRVOpLowering::SPIRVOpLowering; @@ -116,9 +151,8 @@ ConversionPatternRewriter &rewriter) const override; }; -/// Convert select -> spv.Select -// TODO(ravishankarm) : This should be moved into DRR. -class SelectOpConversion final : public SPIRVOpLowering { +/// Converts std.select to spv.Select. +class SelectOpPattern final : public SPIRVOpLowering { public: using SPIRVOpLowering::SPIRVOpLowering; LogicalResult @@ -126,11 +160,8 @@ ConversionPatternRewriter &rewriter) const override; }; -/// Convert store -> spv.StoreOp. The operands of the replaced operation are -/// of IndexType while that of the replacement operation are of type i32. This -/// is not supported in tablegen based pattern specification.
-// TODO(ravishankarm) : This should be moved into DRR. -class StoreOpConversion final : public SPIRVOpLowering { +/// Converts std.store to spv.Store. +class StoreOpPattern final : public SPIRVOpLowering { public: using SPIRVOpLowering::SPIRVOpLowering; @@ -139,13 +170,47 @@ ConversionPatternRewriter &rewriter) const override; }; +/// Converts type-casting standard operations to SPIR-V operations. +template +class TypeCastingOpPattern final : public SPIRVOpLowering { +public: + using SPIRVOpLowering::SPIRVOpLowering; + + LogicalResult + matchAndRewrite(StdOp operation, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + assert(operands.size() == 1); + auto dstType = this->typeConverter.convertType(operation.getType()); + if (!dstType) + return failure(); + if (dstType == operands.front().getType()) { + // Types coincide after conversion, so just forward the operand. + rewriter.replaceOp(operation, operands.front()); + } else { + rewriter.template replaceOpWithNewOp( + operation, dstType, operands, ArrayRef()); + } + return success(); + } +}; + +/// Converts non-boolean std.xor to SPIR-V ops (boolean xor fails to match). +class XOrOpPattern final : public SPIRVOpLowering { +public: + using SPIRVOpLowering::SPIRVOpLowering; + + LogicalResult + matchAndRewrite(XOrOp xorOp, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override; +}; + } // namespace //===----------------------------------------------------------------------===// // ConstantOp with composite type.
//===----------------------------------------------------------------------===// -LogicalResult ConstantCompositeOpConversion::matchAndRewrite( +LogicalResult ConstantCompositeOpPattern::matchAndRewrite( ConstantOp constCompositeOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const { auto compositeType = @@ -175,10 +240,10 @@ } //===----------------------------------------------------------------------===// -// ConstantOp with index type. +// ConstantOp with scalar type. //===----------------------------------------------------------------------===// -LogicalResult ConstantIndexOpConversion::matchAndRewrite( +LogicalResult ConstantScalarOpPattern::matchAndRewrite( ConstantOp constIndexOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const { if (!constIndexOp.getResult().getType().isa()) { @@ -213,8 +278,8 @@ //===----------------------------------------------------------------------===// LogicalResult -CmpFOpConversion::matchAndRewrite(CmpFOp cmpFOp, ArrayRef operands, - ConversionPatternRewriter &rewriter) const { +CmpFOpPattern::matchAndRewrite(CmpFOp cmpFOp, ArrayRef operands, + ConversionPatternRewriter &rewriter) const { CmpFOpOperandAdaptor cmpFOpOperands(operands); switch (cmpFOp.getPredicate()) { @@ -253,8 +318,8 @@ //===----------------------------------------------------------------------===// LogicalResult -CmpIOpConversion::matchAndRewrite(CmpIOp cmpIOp, ArrayRef operands, - ConversionPatternRewriter &rewriter) const { +CmpIOpPattern::matchAndRewrite(CmpIOp cmpIOp, ArrayRef operands, + ConversionPatternRewriter &rewriter) const { CmpIOpOperandAdaptor cmpIOpOperands(operands); switch (cmpIOp.getPredicate()) { @@ -286,8 +351,8 @@ //===----------------------------------------------------------------------===// LogicalResult -LoadOpConversion::matchAndRewrite(LoadOp loadOp, ArrayRef operands, - ConversionPatternRewriter &rewriter) const { +LoadOpPattern::matchAndRewrite(LoadOp loadOp, ArrayRef operands, + ConversionPatternRewriter
&rewriter) const { LoadOpOperandAdaptor loadOperands(operands); auto loadPtr = spirv::getElementPtr( typeConverter, loadOp.memref().getType().cast(), @@ -301,8 +366,8 @@ //===----------------------------------------------------------------------===// LogicalResult -ReturnOpConversion::matchAndRewrite(ReturnOp returnOp, ArrayRef operands, - ConversionPatternRewriter &rewriter) const { +ReturnOpPattern::matchAndRewrite(ReturnOp returnOp, ArrayRef operands, + ConversionPatternRewriter &rewriter) const { if (returnOp.getNumOperands()) { return failure(); } @@ -315,8 +380,8 @@ //===----------------------------------------------------------------------===// LogicalResult -SelectOpConversion::matchAndRewrite(SelectOp op, ArrayRef operands, - ConversionPatternRewriter &rewriter) const { +SelectOpPattern::matchAndRewrite(SelectOp op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const { SelectOpOperandAdaptor selectOperands(operands); rewriter.replaceOpWithNewOp(op, selectOperands.condition(), selectOperands.true_value(), @@ -329,8 +394,8 @@ //===----------------------------------------------------------------------===// LogicalResult -StoreOpConversion::matchAndRewrite(StoreOp storeOp, ArrayRef operands, - ConversionPatternRewriter &rewriter) const { +StoreOpPattern::matchAndRewrite(StoreOp storeOp, ArrayRef operands, + ConversionPatternRewriter &rewriter) const { StoreOpOperandAdaptor storeOperands(operands); auto storePtr = spirv::getElementPtr( typeConverter, storeOp.memref().getType().cast(), @@ -341,6 +406,31 @@ return success(); } +//===----------------------------------------------------------------------===// +// XorOp +//===----------------------------------------------------------------------===// + +LogicalResult +XOrOpPattern::matchAndRewrite(XOrOp xorOp, ArrayRef operands, + ConversionPatternRewriter &rewriter) const { + assert(operands.size() == 2); + // Boolean operands are not handled by this pattern. + if (isBoolScalarOrVector(operands.front().getType())) + return failure(); + + auto dstType =
typeConverter.convertType(xorOp.getType()); + if (!dstType) + return failure(); + rewriter.replaceOpWithNewOp(xorOp, dstType, operands, + ArrayRef()); + + return success(); +} + +//===----------------------------------------------------------------------===// +// Pattern population +//===----------------------------------------------------------------------===// + namespace { /// Import the Standard Ops to SPIR-V Patterns. #include "StandardToSPIRV.cpp.inc" @@ -352,14 +442,29 @@ OwningRewritePatternList &patterns) { // Add patterns that lower operations into SPIR-V dialect. populateWithGenerated(context, &patterns); - patterns.insert, - IntegerOpConversion, - IntegerOpConversion, - IntegerOpConversion, - IntegerOpConversion, LoadOpConversion, - ReturnOpConversion, SelectOpConversion, StoreOpConversion>( + patterns.insert< + BinaryOpPattern, + BinaryOpPattern, + BinaryOpPattern, + BinaryOpPattern, + BinaryOpPattern, + BinaryOpPattern, + BinaryOpPattern, + BinaryOpPattern, + BinaryOpPattern, + BinaryOpPattern, + BinaryOpPattern, + BinaryOpPattern, + BinaryOpPattern, + BinaryOpPattern, + BinaryOpPattern, + BitwiseOpPattern, + BitwiseOpPattern, + ConstantCompositeOpPattern, ConstantScalarOpPattern, CmpFOpPattern, + CmpIOpPattern, LoadOpPattern, ReturnOpPattern, SelectOpPattern, + StoreOpPattern, TypeCastingOpPattern, + TypeCastingOpPattern, + TypeCastingOpPattern, XOrOpPattern>( context, typeConverter); } } // namespace mlir diff --git a/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.td b/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.td --- a/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.td +++ b/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.td @@ -16,34 +16,6 @@ include "mlir/Dialect/StandardOps/IR/Ops.td" include "mlir/Dialect/SPIRV/SPIRVOps.td" -class BinaryOpPattern : - Pat<(src SPV_ScalarOrVectorOf:$l, SPV_ScalarOrVectorOf:$r), - (tgt $l, $r)>; - -class UnaryOpPattern : - Pat<(src type:$input), - (tgt $input)>; - -def : BinaryOpPattern; -def
: BinaryOpPattern; -def : BinaryOpPattern; -def : BinaryOpPattern; -def : BinaryOpPattern; -def : BinaryOpPattern; -def : BinaryOpPattern; -def : BinaryOpPattern; -def : BinaryOpPattern; -def : BinaryOpPattern; -def : BinaryOpPattern; -def : BinaryOpPattern; -def : BinaryOpPattern; - -def : UnaryOpPattern; -def : UnaryOpPattern; -def : UnaryOpPattern; - // Constant Op // TODO(ravishankarm): Handle lowering other constant types. def : Pat<(ConstantOp:$result $valueAttr), diff --git a/mlir/test/Conversion/StandardToSPIRV/std-to-spirv.mlir b/mlir/test/Conversion/StandardToSPIRV/std-to-spirv.mlir --- a/mlir/test/Conversion/StandardToSPIRV/std-to-spirv.mlir +++ b/mlir/test/Conversion/StandardToSPIRV/std-to-spirv.mlir @@ -1,112 +1,142 @@ -// RUN: mlir-opt -convert-std-to-spirv %s -o - | FileCheck %s +// RUN: mlir-opt -split-input-file -convert-std-to-spirv %s -o - | FileCheck %s + +//===----------------------------------------------------------------------===// +// std arithmetic ops +//===----------------------------------------------------------------------===// module attributes { spv.target_env = #spv.target_env< - #spv.vce, + #spv.vce, {max_compute_workgroup_invocations = 128 : i32, max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}> } { -//===----------------------------------------------------------------------===// -// std binary arithmetic ops -//===----------------------------------------------------------------------===// - -// CHECK-LABEL: @add_sub -func @add_sub(%arg0 : i32, %arg1 : i32) { - // CHECK: spv.IAdd - %0 = addi %arg0, %arg1 : i32 - // CHECK: spv.ISub - %1 = subi %arg0, %arg1 : i32 +// Check integer operation conversions.
+// CHECK-LABEL: @int32_scalar +func @int32_scalar(%lhs: i32, %rhs: i32) { + // CHECK: spv.IAdd %{{.*}}, %{{.*}}: i32 + %0 = addi %lhs, %rhs: i32 + // CHECK: spv.ISub %{{.*}}, %{{.*}}: i32 + %1 = subi %lhs, %rhs: i32 + // CHECK: spv.IMul %{{.*}}, %{{.*}}: i32 + %2 = muli %lhs, %rhs: i32 + // CHECK: spv.SDiv %{{.*}}, %{{.*}}: i32 + %3 = divi_signed %lhs, %rhs: i32 + // CHECK: spv.SRem %{{.*}}, %{{.*}}: i32 + %4 = remi_signed %lhs, %rhs: i32 + // CHECK: spv.UDiv %{{.*}}, %{{.*}}: i32 + %5 = divi_unsigned %lhs, %rhs: i32 + // CHECK: spv.UMod %{{.*}}, %{{.*}}: i32 + %6 = remi_unsigned %lhs, %rhs: i32 return } -// CHECK-LABEL: @fadd_scalar -func @fadd_scalar(%arg: f32) { - // CHECK: spv.FAdd - %0 = addf %arg, %arg : f32 +// Check float operation conversions. +// CHECK-LABEL: @float32_scalar +func @float32_scalar(%lhs: f32, %rhs: f32) { + // CHECK: spv.FAdd %{{.*}}, %{{.*}}: f32 + %0 = addf %lhs, %rhs: f32 + // CHECK: spv.FSub %{{.*}}, %{{.*}}: f32 + %1 = subf %lhs, %rhs: f32 + // CHECK: spv.FMul %{{.*}}, %{{.*}}: f32 + %2 = mulf %lhs, %rhs: f32 + // CHECK: spv.FDiv %{{.*}}, %{{.*}}: f32 + %3 = divf %lhs, %rhs: f32 + // CHECK: spv.FRem %{{.*}}, %{{.*}}: f32 + %4 = remf %lhs, %rhs: f32 return } -// CHECK-LABEL: @fdiv_scalar -func @fdiv_scalar(%arg: f32) { - // CHECK: spv.FDiv - %0 = divf %arg, %arg : f32 +// Check int vector types. +// CHECK-LABEL: @int_vector234 +func @int_vector234(%arg0: vector<2xi8>, %arg1: vector<3xi16>, %arg2: vector<4xi64>) { + // CHECK: spv.SDiv %{{.*}}, %{{.*}}: vector<2xi8> + %0 = divi_signed %arg0, %arg0: vector<2xi8> + // CHECK: spv.SRem %{{.*}}, %{{.*}}: vector<3xi16> + %1 = remi_signed %arg1, %arg1: vector<3xi16> + // CHECK: spv.UDiv %{{.*}}, %{{.*}}: vector<4xi64> + %2 = divi_unsigned %arg2, %arg2: vector<4xi64> return } -// CHECK-LABEL: @fmul_scalar -func @fmul_scalar(%arg: f32) { - // CHECK: spv.FMul - %0 = mulf %arg, %arg : f32 +// Check float vector types.
+// CHECK-LABEL: @float_vector234 +func @float_vector234(%arg0: vector<2xf16>, %arg1: vector<3xf64>) { + // CHECK: spv.FAdd %{{.*}}, %{{.*}}: vector<2xf16> + %0 = addf %arg0, %arg0: vector<2xf16> + // CHECK: spv.FMul %{{.*}}, %{{.*}}: vector<3xf64> + %1 = mulf %arg1, %arg1: vector<3xf64> return } -// CHECK-LABEL: @fmul_vector2 -func @fmul_vector2(%arg: vector<2xf32>) { - // CHECK: spv.FMul - %0 = mulf %arg, %arg : vector<2xf32> +// CHECK-LABEL: @unsupported_1elem_vector +func @unsupported_1elem_vector(%arg0: vector<1xi32>) { + // CHECK: addi + %0 = addi %arg0, %arg0: vector<1xi32> return } -// CHECK-LABEL: @fmul_vector3 -func @fmul_vector3(%arg: vector<3xf32>) { - // CHECK: spv.FMul - %0 = mulf %arg, %arg : vector<3xf32> +// CHECK-LABEL: @unsupported_5elem_vector +func @unsupported_5elem_vector(%arg0: vector<5xi32>) { + // CHECK: subi + %1 = subi %arg0, %arg0: vector<5xi32> return } -// CHECK-LABEL: @fmul_vector4 -func @fmul_vector4(%arg: vector<4xf32>) { - // CHECK: spv.FMul - %0 = mulf %arg, %arg : vector<4xf32> +// CHECK-LABEL: @unsupported_2x2elem_vector +func @unsupported_2x2elem_vector(%arg0: vector<2x2xi32>) { + // CHECK: muli + %2 = muli %arg0, %arg0: vector<2x2xi32> return } -// CHECK-LABEL: @fmul_vector5 -func @fmul_vector5(%arg: vector<5xf32>) { - // Vector length of only 2, 3, and 4 is valid for SPIR-V. - // CHECK: mulf - %0 = mulf %arg, %arg : vector<5xf32> - return -} +} // end module -// TODO(antiagainst): enable this once we support converting binary ops -// needing type conversion. -// XXXXX-LABEL: @fmul_tensor -//func @fmul_tensor(%arg: tensor<4xf32>) { - // For tensors mulf cannot be lowered directly to spv.FMul. - // XXXXX: mulf - //%0 = mulf %arg, %arg : tensor<4xf32> - //return -//} - -// CHECK-LABEL: @frem_scalar -func @frem_scalar(%arg: f32) { - // CHECK: spv.FRem - %0 = remf %arg, %arg : f32 - return -} +// ----- + +// Check that types are converted to 32-bit without special capabilities.
+module attributes { + spv.target_env = #spv.target_env< + #spv.vce, + {max_compute_workgroup_invocations = 128 : i32, + max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}> +} { -// CHECK-LABEL: @fsub_scalar -func @fsub_scalar(%arg: f32) { - // CHECK: spv.FSub - %0 = subf %arg, %arg : f32 +// CHECK-LABEL: @int_vector234 +func @int_vector234(%arg0: vector<2xi8>, %arg1: vector<3xi16>, %arg2: vector<4xi64>) { + // CHECK: spv.SDiv %{{.*}}, %{{.*}}: vector<2xi32> + %0 = divi_signed %arg0, %arg0: vector<2xi8> + // CHECK: spv.SRem %{{.*}}, %{{.*}}: vector<3xi32> + %1 = remi_signed %arg1, %arg1: vector<3xi16> + // CHECK: spv.UDiv %{{.*}}, %{{.*}}: vector<4xi32> + %2 = divi_unsigned %arg2, %arg2: vector<4xi64> return } -// CHECK-LABEL: @div_rem -func @div_rem(%arg0 : i32, %arg1 : i32) { - // CHECK: spv.SDiv - %0 = divi_signed %arg0, %arg1 : i32 - // CHECK: spv.SMod - %1 = remi_signed %arg0, %arg1 : i32 +// CHECK-LABEL: @float_scalar +func @float_scalar(%arg0: f16, %arg1: f64) { + // CHECK: spv.FAdd %{{.*}}, %{{.*}}: f32 + %0 = addf %arg0, %arg0: f16 + // CHECK: spv.FMul %{{.*}}, %{{.*}}: f32 + %1 = mulf %arg1, %arg1: f64 return } +} // end module + +// ----- + //===----------------------------------------------------------------------===// // std bit ops //===----------------------------------------------------------------------===// +module attributes { + spv.target_env = #spv.target_env< + #spv.vce, + {max_compute_workgroup_invocations = 128 : i32, + max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}> +} { + // CHECK-LABEL: @bitwise_scalar func @bitwise_scalar(%arg0 : i32, %arg1 : i32) { // CHECK: spv.BitwiseAnd @@ -129,6 +159,24 @@ return } +// CHECK-LABEL: @logical_scalar +func @logical_scalar(%arg0 : i1, %arg1 : i1) { + // CHECK: spv.LogicalAnd + %0 = and %arg0, %arg1 : i1 + // CHECK: spv.LogicalOr + %1 = or %arg0, %arg1 : i1 + return +} + +// CHECK-LABEL: @logical_vector +func @logical_vector(%arg0 : vector<4xi1>, %arg1 :
vector<4xi1>) { + // CHECK: spv.LogicalAnd + %0 = and %arg0, %arg1 : vector<4xi1> + // CHECK: spv.LogicalOr + %1 = or %arg0, %arg1 : vector<4xi1> + return +} + // CHECK-LABEL: @shift_scalar func @shift_scalar(%arg0 : i32, %arg1 : i32) { // CHECK: spv.ShiftLeftLogical @@ -213,10 +261,21 @@ return } +} // end module + +// ----- + //===----------------------------------------------------------------------===// // std.constant //===----------------------------------------------------------------------===// +module attributes { + spv.target_env = #spv.target_env< + #spv.vce, + {max_compute_workgroup_invocations = 128 : i32, + max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}> +} { + // CHECK-LABEL: @constant func @constant() { // CHECK: spv.constant true @@ -244,50 +303,126 @@ return } +} // end module + +// ----- + //===----------------------------------------------------------------------===// -// std logical binary operations +// std cast ops //===----------------------------------------------------------------------===// -// CHECK-LABEL: @logical_scalar -func @logical_scalar(%arg0 : i1, %arg1 : i1) { - // CHECK: spv.LogicalAnd - %0 = and %arg0, %arg1 : i1 - // CHECK: spv.LogicalOr - %1 = or %arg0, %arg1 : i1 - return +module attributes { + spv.target_env = #spv.target_env< + #spv.vce, + {max_compute_workgroup_invocations = 128 : i32, + max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}> +} { + +// CHECK-LABEL: @fpext1 +func @fpext1(%arg0: f16) -> f64 { + // CHECK: spv.FConvert %{{.*}} : f16 to f64 + %0 = std.fpext %arg0 : f16 to f64 + return %0 : f64 } -// CHECK-LABEL: @logical_vector -func @logical_vector(%arg0 : vector<4xi1>, %arg1 : vector<4xi1>) { - // CHECK: spv.LogicalAnd - %0 = and %arg0, %arg1 : vector<4xi1> - // CHECK: spv.LogicalOr - %1 = or %arg0, %arg1 : vector<4xi1> - return +// CHECK-LABEL: @fpext2 +func @fpext2(%arg0 : f32) -> f64 { + // CHECK: spv.FConvert %{{.*}} : f32 to f64 + %0 = std.fpext %arg0 : f32 to f64 +
return %0 : f64 } -//===----------------------------------------------------------------------===// -// std.fpext -//===----------------------------------------------------------------------===// +// CHECK-LABEL: @fptrunc1 +func @fptrunc1(%arg0 : f64) -> f16 { + // CHECK: spv.FConvert %{{.*}} : f64 to f16 + %0 = std.fptrunc %arg0 : f64 to f16 + return %0 : f16 +} + +// CHECK-LABEL: @fptrunc2 +func @fptrunc2(%arg0: f32) -> f16 { + // CHECK: spv.FConvert %{{.*}} : f32 to f16 + %0 = std.fptrunc %arg0 : f32 to f16 + return %0 : f16 +} + +// CHECK-LABEL: @sitofp1 +func @sitofp1(%arg0 : i32) -> f32 { + // CHECK: spv.ConvertSToF %{{.*}} : i32 to f32 + %0 = std.sitofp %arg0 : i32 to f32 + return %0 : f32 +} + +// CHECK-LABEL: @sitofp2 +func @sitofp2(%arg0 : i64) -> f64 { + // CHECK: spv.ConvertSToF %{{.*}} : i64 to f64 + %0 = std.sitofp %arg0 : i64 to f64 + return %0 : f64 +} + +} // end module + +// ----- + +// Checks that cast types are adjusted to 32-bit when there are no special +// capabilities for non-32-bit scalar types.
+module attributes { + spv.target_env = #spv.target_env< + #spv.vce, + {max_compute_workgroup_invocations = 128 : i32, + max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}> +} { -// CHECK-LABEL: @fpext -func @fpext(%arg0 : f32) { - // CHECK: spv.FConvert +// CHECK-LABEL: @fpext1 +// CHECK-SAME: %[[ARG:.*]]: f32 +func @fpext1(%arg0: f16) { + // CHECK-NEXT: "use"(%[[ARG]]) + %0 = std.fpext %arg0 : f16 to f64 + "use"(%0) : (f64) -> () +} + +// CHECK-LABEL: @fpext2 +// CHECK-SAME: %[[ARG:.*]]: f32 +func @fpext2(%arg0 : f32) { + // CHECK-NEXT: "use"(%[[ARG]]) %0 = std.fpext %arg0 : f32 to f64 - return + "use"(%0) : (f64) -> () } -//===----------------------------------------------------------------------===// -// std.fptrunc -//===----------------------------------------------------------------------===// +// CHECK-LABEL: @fptrunc1 +// CHECK-SAME: %[[ARG:.*]]: f32 +func @fptrunc1(%arg0 : f64) { + // CHECK-NEXT: "use"(%[[ARG]]) + %0 = std.fptrunc %arg0 : f64 to f16 + "use"(%0) : (f16) -> () +} -// CHECK-LABEL: @fptrunc -func @fptrunc(%arg0 : f64) { - // CHECK: spv.FConvert - %0 = std.fptrunc %arg0 : f64 to f32 - return +// CHECK-LABEL: @fptrunc2 +// CHECK-SAME: %[[ARG:.*]]: f32 +func @fptrunc2(%arg0: f32) { + // CHECK-NEXT: "use"(%[[ARG]]) + %0 = std.fptrunc %arg0 : f32 to f16 + "use"(%0) : (f16) -> () +} + +// CHECK-LABEL: @sitofp +func @sitofp(%arg0 : i64) { + // CHECK: spv.ConvertSToF %{{.*}} : i32 to f32 + %0 = std.sitofp %arg0 : i64 to f64 + "use"(%0) : (f64) -> () } +} // end module + +// ----- + +module attributes { + spv.target_env = #spv.target_env< + #spv.vce, + {max_compute_workgroup_invocations = 128 : i32, + max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}> +} { + //===----------------------------------------------------------------------===// // std.select //===----------------------------------------------------------------------===// @@ -301,41 +436,9 @@ }
//===----------------------------------------------------------------------===// -// std.sitofp +// std load/store ops //===----------------------------------------------------------------------===// -// CHECK-LABEL: @sitofp -func @sitofp(%arg0 : i32) { - // CHECK: spv.ConvertSToF - %0 = std.sitofp %arg0 : i32 to f32 - return -} - -//===----------------------------------------------------------------------===// -// memref type -//===----------------------------------------------------------------------===// - -// CHECK-LABEL: func @memref_type({{%.*}}: memref<3xi1>) -func @memref_type(%arg0: memref<3xi1>) { - return -} - -// CHECK-LABEL: func @memref_mem_space -// CHECK-SAME: StorageBuffer -// CHECK-SAME: Uniform -// CHECK-SAME: Workgroup -// CHECK-SAME: PushConstant -// CHECK-SAME: Private -// CHECK-SAME: Function -func @memref_mem_space( - %arg0: memref<4xf32, 0>, - %arg1: memref<4xf32, 4>, - %arg2: memref<4xf32, 3>, - %arg3: memref<4xf32, 7>, - %arg4: memref<4xf32, 5>, - %arg5: memref<4xf32, 6> -) { return } - // CHECK-LABEL: @load_store_zero_rank_float // CHECK: [[ARG0:%.*]]: !spv.ptr [0]>, StorageBuffer>, // CHECK: [[ARG1:%.*]]: !spv.ptr [0]>, StorageBuffer>) diff --git a/mlir/test/Conversion/StandardToSPIRV/std-types-to-spirv.mlir b/mlir/test/Conversion/StandardToSPIRV/std-types-to-spirv.mlir --- a/mlir/test/Conversion/StandardToSPIRV/std-types-to-spirv.mlir +++ b/mlir/test/Conversion/StandardToSPIRV/std-types-to-spirv.mlir @@ -255,6 +255,51 @@ // MemRef types //===----------------------------------------------------------------------===// +// Check that memref memory spaces map to SPIR-V storage classes.
+module attributes { + spv.target_env = #spv.target_env< + #spv.vce, + {max_compute_workgroup_invocations = 128 : i32, + max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}> +} { + +// CHECK-LABEL: func @memref_mem_space +// CHECK-SAME: StorageBuffer +// CHECK-SAME: Uniform +// CHECK-SAME: Workgroup +// CHECK-SAME: PushConstant +// CHECK-SAME: Private +// CHECK-SAME: Function +func @memref_mem_space( + %arg0: memref<4xf32, 0>, + %arg1: memref<4xf32, 4>, + %arg2: memref<4xf32, 3>, + %arg3: memref<4xf32, 7>, + %arg4: memref<4xf32, 5>, + %arg5: memref<4xf32, 6> +) { return } + +} // end module + +// ----- + +// Check that boolean memrefs are left unconverted (not supported yet). +module attributes { + spv.target_env = #spv.target_env< + #spv.vce, + {max_compute_workgroup_invocations = 128 : i32, + max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}> +} { + +// CHECK-LABEL: func @memref_type({{%.*}}: memref<3xi1>) +func @memref_type(%arg0: memref<3xi1>) { + return +} + +} // end module + +// ----- + // Check that using non-32-bit scalar types in interface storage classes // requires special capability and extension: convert them to 32-bit if not // satisfied.