diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVGLSLOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVGLSLOps.td
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVGLSLOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVGLSLOps.td
@@ -1221,4 +1221,20 @@
   let hasVerifier = 0;
 }
 
+def SPV_GLSLFindUMsbOp : SPV_GLSLUnaryArithmeticOp<"FindUMsb", 75, SPV_Int32> {
+  let summary = "Unsigned-integer most-significant bit";
+
+  let description = [{
+    Results in the bit number of the most-significant 1-bit in the binary
+    representation of Value. If Value is 0, the result is -1.
+
+    Result Type and the type of Value must both be integer scalar or
+    integer vector types. Result Type and operand types must have the
+    same number of components with the same component width. Results are
+    computed per component.
+
+    This instruction is currently limited to 32-bit width components.
+  }];
+}
+
 #endif // MLIR_DIALECT_SPIRV_IR_GLSL_OPS
diff --git a/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp b/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
--- a/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
+++ b/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
@@ -16,12 +16,35 @@
 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
 #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
 #include "mlir/IR/BuiltinTypes.h"
+#include "mlir/Transforms/DialectConversion.h"
 #include "llvm/Support/Debug.h"
 
 #define DEBUG_TYPE "math-to-spirv-pattern"
 
 using namespace mlir;
 
+//===----------------------------------------------------------------------===//
+// Utility functions
+//===----------------------------------------------------------------------===//
+
+/// Creates a 32-bit scalar/vector integer constant. Returns nullptr if the
+/// given type is not a 32-bit scalar/vector type.
+static Value getScalarOrVectorI32Constant(Type type, int value,
+                                          OpBuilder &builder, Location loc) {
+  if (auto vectorType = type.dyn_cast<VectorType>()) {
+    if (!vectorType.getElementType().isInteger(32))
+      return nullptr;
+    SmallVector<int> values(vectorType.getNumElements(), value);
+    return builder.create<spirv::ConstantOp>(loc, type,
+                                             builder.getI32VectorAttr(values));
+  }
+  if (type.isInteger(32))
+    return builder.create<spirv::ConstantOp>(loc, type,
+                                             builder.getI32IntegerAttr(value));
+
+  return nullptr;
+}
+
 //===----------------------------------------------------------------------===//
 // Operation conversion
 //===----------------------------------------------------------------------===//
@@ -92,6 +115,51 @@
   }
 };
 
+/// Converts math.ctlz to SPIR-V ops.
+///
+/// SPIR-V does not have a direct operation for counting leading zeros. If
+/// Shader capability is supported, we can leverage GLSL FindUMsb to calculate
+/// it.
+class CountLeadingZerosPattern final
+    : public OpConversionPattern<math::CountLeadingZerosOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(math::CountLeadingZerosOp countOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto type = getTypeConverter()->convertType(countOp.getType());
+    if (!type)
+      return failure();
+
+    // Returns the element bitwidth of a scalar integer or vector type, or 0
+    // for any other type.
+    auto getBitwidth = [](Type t) -> unsigned {
+      if (t.isa<IntegerType>())
+        return t.getIntOrFloatBitWidth();
+      if (auto vectorType = t.dyn_cast<VectorType>())
+        return vectorType.getElementTypeBitWidth();
+      return 0;
+    };
+
+    // We can only support 32-bit integer types for now. Check the source type
+    // too: the type converter may emulate other integer widths (e.g. i16 or
+    // i64) with i32, and a 32-bit leading-zero count would then be wrong for
+    // the original type.
+    if (getBitwidth(countOp.getType()) != 32 || getBitwidth(type) != 32)
+      return failure();
+
+    Location loc = countOp.getLoc();
+    Value val31 = getScalarOrVectorI32Constant(type, 31, rewriter, loc);
+    Value msb =
+        rewriter.create<spirv::GLSLFindUMsbOp>(loc, adaptor.getOperand());
+    // FindUMsb returns the bit index counted from the least significant bit,
+    // so subtract it from 31. For a zero input FindUMsb returns -1, which
+    // yields the expected count of 32.
+    rewriter.replaceOpWithNewOp<spirv::ISubOp>(countOp, val31, msb);
+    return success();
+  }
+};
+
 /// Converts math.expm1 to SPIR-V ops.
 ///
 /// SPIR-V does not have a direct operations for exp(x)-1. Explicitly lower to
@@ -148,7 +216,8 @@
 
   // GLSL patterns
   patterns
-      .add<Log1pOpPattern<spirv::GLSLLogOp>, ExpM1OpPattern<spirv::GLSLExpOp>,
+      .add<CountLeadingZerosPattern, Log1pOpPattern<spirv::GLSLLogOp>,
+           ExpM1OpPattern<spirv::GLSLExpOp>,
            spirv::ElementwiseOpPattern<math::AbsOp, spirv::GLSLFAbsOp>,
            spirv::ElementwiseOpPattern<math::CeilOp, spirv::GLSLCeilOp>,
            spirv::ElementwiseOpPattern<math::CosOp, spirv::GLSLCosOp>,
diff --git a/mlir/lib/Conversion/MathToSPIRV/MathToSPIRVPass.cpp b/mlir/lib/Conversion/MathToSPIRV/MathToSPIRVPass.cpp
--- a/mlir/lib/Conversion/MathToSPIRV/MathToSPIRVPass.cpp
+++ b/mlir/lib/Conversion/MathToSPIRV/MathToSPIRVPass.cpp
@@ -36,6 +36,17 @@
 
   SPIRVTypeConverter typeConverter(targetAttr);
 
+  // Use UnrealizedConversionCast as the bridge so that we don't need to pull
+  // in patterns for other dialects.
+  auto addUnrealizedCast = [](OpBuilder &builder, Type type, ValueRange inputs,
+                              Location loc) {
+    auto cast = builder.create<UnrealizedConversionCastOp>(loc, type, inputs);
+    return Optional<Value>(cast.getResult(0));
+  };
+  typeConverter.addSourceMaterialization(addUnrealizedCast);
+  typeConverter.addTargetMaterialization(addUnrealizedCast);
+  target->addLegalOp<UnrealizedConversionCastOp>();
+
   RewritePatternSet patterns(context);
   populateMathToSPIRVPatterns(typeConverter, patterns);
 
diff --git a/mlir/test/Conversion/MathToSPIRV/math-to-glsl-spirv.mlir b/mlir/test/Conversion/MathToSPIRV/math-to-glsl-spirv.mlir
--- a/mlir/test/Conversion/MathToSPIRV/math-to-glsl-spirv.mlir
+++ b/mlir/test/Conversion/MathToSPIRV/math-to-glsl-spirv.mlir
@@ -1,6 +1,8 @@
 // RUN: mlir-opt -split-input-file -convert-math-to-spirv -verify-diagnostics %s -o - | FileCheck %s
 
-module attributes { spv.target_env = #spv.target_env<#spv.vce<v1.0, [Shader], []>, #spv.resource_limits<>> } {
+module attributes {
+  spv.target_env = #spv.target_env<#spv.vce<v1.0, [Shader], []>, #spv.resource_limits<>>
+} {
 
 // CHECK-LABEL: @float32_unary_scalar
 func.func @float32_unary_scalar(%arg0: f32) {
@@ -91,4 +93,65 @@
   return
 }
 
+// CHECK-LABEL: @ctlz_scalar
+// CHECK-SAME: (%[[VAL:.+]]: i32)
+func.func @ctlz_scalar(%val: i32) -> i32 {
+  // CHECK: %[[V31:.+]] = spv.Constant 31 : i32
+  // CHECK: %[[MSB:.+]] = spv.GLSL.FindUMsb %[[VAL]] : i32
+  // CHECK: %[[SUB:.+]] = spv.ISub %[[V31]], %[[MSB]] : i32
+  // CHECK: return %[[SUB]]
+  %0 = math.ctlz %val : i32
+  return %0 : i32
+}
+
+// CHECK-LABEL: @ctlz_vector1
+func.func @ctlz_vector1(%val: vector<1xi32>) -> vector<1xi32> {
+  // CHECK: spv.GLSL.FindUMsb
+  // CHECK: spv.ISub
+  %0 = math.ctlz %val : vector<1xi32>
+  return %0 : vector<1xi32>
+}
+
+// CHECK-LABEL: @ctlz_vector2
+// CHECK-SAME: (%[[VAL:.+]]: vector<2xi32>)
+func.func @ctlz_vector2(%val: vector<2xi32>) -> vector<2xi32> {
+  // CHECK-DAG: %[[V31:.+]] = spv.Constant dense<31> : vector<2xi32>
+  // CHECK: %[[MSB:.+]] = spv.GLSL.FindUMsb %[[VAL]] : vector<2xi32>
+  // CHECK: %[[SUB:.+]] = spv.ISub %[[V31]], %[[MSB]] : vector<2xi32>
+  // CHECK: return %[[SUB]]
+  %0 = math.ctlz %val : vector<2xi32>
+  return %0 : vector<2xi32>
+}
+
+// i16 has no native support on this target, so the type converter may emulate
+// it with i32; a 32-bit leading-zero count would be wrong for the original type.
+// CHECK-LABEL: @ctlz_scalar_emulated_i16
+func.func @ctlz_scalar_emulated_i16(%val: i16) -> i16 {
+  // CHECK: math.ctlz
+  %0 = math.ctlz %val : i16
+  return %0 : i16
+}
+
+} // end module
+
+// -----
+
+module attributes {
+  spv.target_env = #spv.target_env<#spv.vce<v1.0, [Shader, Int64, Int16], []>, #spv.resource_limits<>>
+} {
+
+// CHECK-LABEL: @ctlz_scalar
+func.func @ctlz_scalar(%val: i64) -> i64 {
+  // CHECK: math.ctlz
+  %0 = math.ctlz %val : i64
+  return %0 : i64
+}
+
+// CHECK-LABEL: @ctlz_vector2
+func.func @ctlz_vector2(%val: vector<2xi16>) -> vector<2xi16> {
+  // CHECK: math.ctlz
+  %0 = math.ctlz %val : vector<2xi16>
+  return %0 : vector<2xi16>
+}
+
 } // end module
diff --git a/mlir/test/Dialect/SPIRV/IR/glsl-ops.mlir b/mlir/test/Dialect/SPIRV/IR/glsl-ops.mlir
--- a/mlir/test/Dialect/SPIRV/IR/glsl-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/glsl-ops.mlir
@@ -494,10 +494,34 @@
   return
 }
 
-// -----
 func.func @fmix_vector(%arg0 : vector<3xf32>, %arg1 : vector<3xf32>, %arg2 : vector<3xf32>) -> () {
   // CHECK: {{%.*}} = spv.GLSL.FMix {{%.*}} : vector<3xf32>, {{%.*}} : vector<3xf32>, {{%.*}} : vector<3xf32> -> vector<3xf32>
   %0 = spv.GLSL.FMix %arg0 : vector<3xf32>, %arg1 : vector<3xf32>, %arg2 : vector<3xf32> -> vector<3xf32>
   return
 }
 
+// -----
+
+//===----------------------------------------------------------------------===//
+// spv.GLSL.FindUMsb
+//===----------------------------------------------------------------------===//
+
+func.func @findumsb(%arg0 : i32) -> () {
+  // CHECK: spv.GLSL.FindUMsb {{%.*}} : i32
+  %2 = spv.GLSL.FindUMsb %arg0 : i32
+  return
+}
+
+func.func @findumsb_vector(%arg0 : vector<3xi32>) -> () {
+  // CHECK: spv.GLSL.FindUMsb {{%.*}} : vector<3xi32>
+  %2 = spv.GLSL.FindUMsb %arg0 : vector<3xi32>
+  return
+}
+
+// -----
+
+func.func @findumsb(%arg0 : i64) -> () {
+  // expected-error @+1 {{operand #0 must be Int32 or vector of Int32}}
+  %2 = spv.GLSL.FindUMsb %arg0 : i64
+  return
+}
diff --git a/mlir/test/Target/SPIRV/glsl-ops.mlir b/mlir/test/Target/SPIRV/glsl-ops.mlir
--- a/mlir/test/Target/SPIRV/glsl-ops.mlir
+++ b/mlir/test/Target/SPIRV/glsl-ops.mlir
@@ -75,4 +75,10 @@
     %13 = spv.GLSL.Fma %arg0, %arg1, %arg2 : f32
     spv.Return
   }
+
+  spv.func @findumsb(%arg0 : i32) "None" {
+    // CHECK: spv.GLSL.FindUMsb {{%.*}} : i32
+    %2 = spv.GLSL.FindUMsb %arg0 : i32
+    spv.Return
+  }
 }