diff --git a/mlir/lib/Conversion/StandardToSPIRV/CMakeLists.txt b/mlir/lib/Conversion/StandardToSPIRV/CMakeLists.txt --- a/mlir/lib/Conversion/StandardToSPIRV/CMakeLists.txt +++ b/mlir/lib/Conversion/StandardToSPIRV/CMakeLists.txt @@ -1,7 +1,3 @@ -set(LLVM_TARGET_DEFINITIONS StandardToSPIRV.td) -mlir_tablegen(StandardToSPIRV.cpp.inc -gen-rewriters) -add_public_tablegen_target(MLIRStandardToSPIRVIncGen) - add_mlir_conversion_library(MLIRStandardToSPIRVTransforms ConvertStandardToSPIRV.cpp ConvertStandardToSPIRVPass.cpp @@ -10,9 +6,6 @@ ADDITIONAL_HEADER_DIRS ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/SPIRV ${MLIR_MAIN_INCLUDE_DIR}/mlir/IR - - DEPENDS - MLIRStandardToSPIRVIncGen ) target_link_libraries(MLIRStandardToSPIRVTransforms diff --git a/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp b/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp --- a/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp +++ b/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp @@ -18,6 +18,9 @@ #include "mlir/IR/AffineMap.h" #include "mlir/Support/LogicalResult.h" #include "llvm/ADT/SetVector.h" +#include "llvm/Support/Debug.h" + +#define DEBUG_TYPE "std-to-spirv-pattern" using namespace mlir; @@ -34,6 +37,66 @@ return false; } +/// Converts the given `srcAttr` into a boolean attribute if it holds an +/// integral value. Returns null attribute if conversion fails. +static BoolAttr convertBoolAttr(Attribute srcAttr, Builder builder) { + if (auto boolAttr = srcAttr.dyn_cast()) + return boolAttr; + if (auto intAttr = srcAttr.dyn_cast()) + return builder.getBoolAttr(intAttr.getValue().getBoolValue()); + return BoolAttr(); +} + +/// Converts the given `srcAttr` to a new attribute of the given `dstType`. +/// Returns null attribute if conversion fails.
+static IntegerAttr convertIntegerAttr(IntegerAttr srcAttr, IntegerType dstType, + Builder builder) { + // If the source number's active bits fit within the target bitwidth, then + // it should be safe to convert. + if (srcAttr.getValue().isIntN(dstType.getWidth())) + return builder.getIntegerAttr(dstType, srcAttr.getInt()); + + // XXX: Try again by interpreting the source number as a signed value. + // Although integers in the standard dialect are signless, they can represent + // a signed number. It's the operation that decides how to interpret it. This + // is dangerous, but it seems there is no good way of handling this if we + // still want to change the bitwidth. Emit a message at least. + if (srcAttr.getValue().isSignedIntN(dstType.getWidth())) { + auto dstAttr = builder.getIntegerAttr(dstType, srcAttr.getInt()); + LLVM_DEBUG(llvm::dbgs() << "attribute '" << srcAttr << "' converted to '" + << dstAttr << "' for type '" << dstType << "'\n"); + return dstAttr; + } + + LLVM_DEBUG(llvm::dbgs() << "attribute '" << srcAttr + << "' illegal: cannot fit into target type '" + << dstType << "'\n"); + return IntegerAttr(); +} + +/// Converts the given `srcAttr` to a new attribute of the given `dstType`. +/// Returns null attribute if `dstType` is not 32-bit or conversion fails. +static FloatAttr convertFloatAttr(FloatAttr srcAttr, FloatType dstType, + Builder builder) { + // Only support converting to float for now. + if (!dstType.isF32()) + return FloatAttr(); + + // Try to convert the source floating-point number to single precision.
+ APFloat dstVal = srcAttr.getValue(); + bool losesInfo = false; + APFloat::opStatus status = + dstVal.convert(APFloat::IEEEsingle(), APFloat::rmTowardZero, &losesInfo); + if (status != APFloat::opOK || losesInfo) { + LLVM_DEBUG(llvm::dbgs() + << srcAttr << " illegal: cannot fit into converted type '" + << dstType << "'\n"); + return FloatAttr(); + } + + return builder.getF32FloatAttr(dstVal.convertToFloat()); +} + //===----------------------------------------------------------------------===// // Operation conversion //===----------------------------------------------------------------------===// @@ -97,7 +160,7 @@ using SPIRVOpLowering::SPIRVOpLowering; LogicalResult - matchAndRewrite(ConstantOp constCompositeOp, ArrayRef operands, + matchAndRewrite(ConstantOp constOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const override; }; @@ -107,7 +170,7 @@ using SPIRVOpLowering::SPIRVOpLowering; LogicalResult - matchAndRewrite(ConstantOp constIndexOp, ArrayRef operands, + matchAndRewrite(ConstantOp constOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const override; }; @@ -211,31 +274,84 @@ //===----------------------------------------------------------------------===// LogicalResult ConstantCompositeOpPattern::matchAndRewrite( - ConstantOp constCompositeOp, ArrayRef operands, + ConstantOp constOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const { - auto compositeType = - constCompositeOp.getResult().getType().dyn_cast(); - if (!compositeType) + auto srcType = constOp.getType().dyn_cast(); + if (!srcType) return failure(); - auto spirvCompositeType = typeConverter.convertType(compositeType); - if (!spirvCompositeType) + // std.constant should only have vector or tensor types.
+ assert(srcType.isa() || srcType.isa()); + + auto dstType = typeConverter.convertType(srcType); + if (!dstType) return failure(); - auto linearizedElements = - constCompositeOp.value().dyn_cast(); - if (!linearizedElements) + auto dstElementsAttr = constOp.value().dyn_cast(); + if (!dstElementsAttr) return failure(); + ShapedType dstAttrType = dstElementsAttr.getType(); - // If composite type has rank greater than one, then perform linearization. - if (compositeType.getRank() > 1) { - auto linearizedType = RankedTensorType::get(compositeType.getNumElements(), - compositeType.getElementType()); - linearizedElements = linearizedElements.reshape(linearizedType); + // If the composite type has more than one dimension, perform linearization. + if (srcType.getRank() > 1) { + if (srcType.isa()) { + dstAttrType = RankedTensorType::get(srcType.getNumElements(), + srcType.getElementType()); + dstElementsAttr = dstElementsAttr.reshape(dstAttrType); + } else { + // TODO(antiagainst): add support for large vectors. + return failure(); + } + } + + Type srcElemType = srcType.getElementType(); + Type dstElemType; + // Tensor types are converted to SPIR-V array types; vector types are + // converted to SPIR-V vector/array types. + if (auto arrayType = dstType.dyn_cast()) + dstElemType = arrayType.getElementType(); + else + dstElemType = dstType.cast().getElementType(); + + // If the source and destination element types are different, perform + // attribute conversion.
+ if (srcElemType != dstElemType) { + SmallVector elements; + if (srcElemType.isa()) { + for (Attribute srcAttr : dstElementsAttr.getAttributeValues()) { + FloatAttr dstAttr = convertFloatAttr( + srcAttr.cast(), dstElemType.cast(), rewriter); + if (!dstAttr) + return failure(); + elements.push_back(dstAttr); + } + } else if (srcElemType.isInteger(1)) { + return failure(); + } else { + for (Attribute srcAttr : dstElementsAttr.getAttributeValues()) { + IntegerAttr dstAttr = + convertIntegerAttr(srcAttr.cast(), + dstElemType.cast(), rewriter); + if (!dstAttr) + return failure(); + elements.push_back(dstAttr); + } + } + + // Unfortunately, we cannot use dialect-specific types for element + // attributes; element attributes only work with standard types. So we need + // to prepare another converted standard type for the destination elements + // attribute. + if (dstAttrType.isa()) + dstAttrType = RankedTensorType::get(dstAttrType.getShape(), dstElemType); + else + dstAttrType = VectorType::get(dstAttrType.getShape(), dstElemType); + + dstElementsAttr = DenseElementsAttr::get(dstAttrType, elements); } - rewriter.replaceOpWithNewOp( - constCompositeOp, spirvCompositeType, linearizedElements); + rewriter.replaceOpWithNewOp(constOp, dstType, + dstElementsAttr); return success(); } @@ -244,32 +360,52 @@ //===----------------------------------------------------------------------===// LogicalResult ConstantScalarOpPattern::matchAndRewrite( - ConstantOp constIndexOp, ArrayRef operands, + ConstantOp constOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const { - if (!constIndexOp.getResult().getType().isa()) { + Type srcType = constOp.getType(); + if (!srcType.isIntOrIndexOrFloat()) return failure(); - } - // The attribute has index type which is not directly supported in - // SPIR-V. Get the integer value and create a new IntegerAttr.
- auto constAttr = constIndexOp.value().dyn_cast(); - if (!constAttr) { + + Type dstType = typeConverter.convertType(srcType); + if (!dstType) return failure(); + + // Floating-point types. + if (srcType.isa()) { + auto srcAttr = constOp.value().cast(); + auto dstAttr = srcAttr; + + // Floating-point types not supported in the target environment are all + // converted to float type. + if (srcType != dstType) { + dstAttr = convertFloatAttr(srcAttr, dstType.cast(), rewriter); + if (!dstAttr) + return failure(); + } + + rewriter.replaceOpWithNewOp(constOp, dstType, dstAttr); + return success(); } - // Use the bitwidth set in the value attribute to decide the result type - // of the SPIR-V constant operation since SPIR-V does not support index - // types. - auto constVal = constAttr.getValue(); - auto constValType = constAttr.getType().dyn_cast(); - if (!constValType) { - return failure(); + // Bool type. + if (srcType.isInteger(1)) { + // std.constant can use 0/1 instead of true/false for i1 values. We need to + // handle that here. + auto dstAttr = convertBoolAttr(constOp.value(), rewriter); + if (!dstAttr) + return failure(); + rewriter.replaceOpWithNewOp(constOp, dstType, dstAttr); + return success(); } - auto spirvConstType = - typeConverter.convertType(constIndexOp.getResult().getType()); - auto spirvConstVal = - rewriter.getIntegerAttr(spirvConstType, constAttr.getInt()); - rewriter.replaceOpWithNewOp(constIndexOp, spirvConstType, - spirvConstVal); + + // IndexType or IntegerType. Index values are converted to 32-bit integer + // values when converting to SPIR-V.
+ auto srcAttr = constOp.value().cast(); + auto dstAttr = + convertIntegerAttr(srcAttr, dstType.cast(), rewriter); + if (!dstAttr) + return failure(); + rewriter.replaceOpWithNewOp(constOp, dstType, dstAttr); return success(); } @@ -431,17 +567,10 @@ // Pattern population //===----------------------------------------------------------------------===// -namespace { -/// Import the Standard Ops to SPIR-V Patterns. -#include "StandardToSPIRV.cpp.inc" -} // namespace - namespace mlir { void populateStandardToSPIRVPatterns(MLIRContext *context, SPIRVTypeConverter &typeConverter, OwningRewritePatternList &patterns) { - // Add patterns that lower operations into SPIR-V dialect. - populateWithGenerated(context, &patterns); patterns.insert< BinaryOpPattern, BinaryOpPattern, diff --git a/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.td b/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.td deleted file mode 100644 --- a/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.td +++ /dev/null @@ -1,25 +0,0 @@ -//==- StandardToSPIRV.td - Standard Ops to SPIR-V Patterns ---*- tablegen -*==// - -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// Defines Patterns to lower standard ops to SPIR-V. -// -//===----------------------------------------------------------------------===// - -#ifndef MLIR_CONVERSION_STANDARDTOSPIRV_TD -#define MLIR_CONVERSION_STANDARDTOSPIRV_TD - -include "mlir/Dialect/StandardOps/IR/Ops.td" -include "mlir/Dialect/SPIRV/SPIRVOps.td" - -// Constant Op -// TODO(ravishankarm): Handle lowering other constant types.
-def : Pat<(ConstantOp:$result $valueAttr), - (SPV_ConstantOp $valueAttr), - [(SPV_ScalarOrVector $result)]>; - -#endif // MLIR_CONVERSION_STANDARDTOSPIRV_TD diff --git a/mlir/test/Conversion/StandardToSPIRV/std-to-spirv.mlir b/mlir/test/Conversion/StandardToSPIRV/std-to-spirv.mlir --- a/mlir/test/Conversion/StandardToSPIRV/std-to-spirv.mlir +++ b/mlir/test/Conversion/StandardToSPIRV/std-to-spirv.mlir @@ -280,9 +280,9 @@ func @constant() { // CHECK: spv.constant true %0 = constant true - // CHECK: spv.constant 42 : i64 - %1 = constant 42 - // CHECK: spv.constant {{[0-9]*\.[0-9]*e?-?[0-9]*}} : f32 + // CHECK: spv.constant 42 : i32 + %1 = constant 42 : i32 + // CHECK: spv.constant 5.000000e-01 : f32 %2 = constant 0.5 : f32 // CHECK: spv.constant dense<[2, 3]> : vector<2xi32> %3 = constant dense<[2, 3]> : vector<2xi32> @@ -303,6 +303,114 @@ return } +// CHECK-LABEL: @constant_16bit +func @constant_16bit() { + // CHECK: spv.constant 4 : i16 + %0 = constant 4 : i16 + // CHECK: spv.constant 5.000000e+00 : f16 + %1 = constant 5.0 : f16 + // CHECK: spv.constant dense<[2, 3]> : vector<2xi16> + %2 = constant dense<[2, 3]> : vector<2xi16> + // CHECK: spv.constant dense<4.000000e+00> : tensor<5xf16> : !spv.array<5 x f16 [2]> + %3 = constant dense<4.0> : tensor<5xf16> + return +} + +// CHECK-LABEL: @constant_64bit +func @constant_64bit() { + // CHECK: spv.constant 4 : i64 + %0 = constant 4 : i64 + // CHECK: spv.constant 5.000000e+00 : f64 + %1 = constant 5.0 : f64 + // CHECK: spv.constant dense<[2, 3]> : vector<2xi64> + %2 = constant dense<[2, 3]> : vector<2xi64> + // CHECK: spv.constant dense<4.000000e+00> : tensor<5xf64> : !spv.array<5 x f64 [8]> + %3 = constant dense<4.0> : tensor<5xf64> + return +} + +} // end module + +// ----- + +// Check that constants are converted to 32-bit without special capabilities.
+module attributes { + spv.target_env = #spv.target_env< + #spv.vce, + {max_compute_workgroup_invocations = 128 : i32, + max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}> +} { + +// CHECK-LABEL: @constant_16bit +func @constant_16bit() { + // CHECK: spv.constant 4 : i32 + %0 = constant 4 : i16 + // CHECK: spv.constant 5.000000e+00 : f32 + %1 = constant 5.0 : f16 + // CHECK: spv.constant dense<[2, 3]> : vector<2xi32> + %2 = constant dense<[2, 3]> : vector<2xi16> + // CHECK: spv.constant dense<4.000000e+00> : tensor<5xf32> : !spv.array<5 x f32 [4]> + %3 = constant dense<4.0> : tensor<5xf16> + // CHECK: spv.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00]> : tensor<4xf32> : !spv.array<4 x f32 [4]> + %4 = constant dense<[[1.0, 2.0], [3.0, 4.0]]> : tensor<2x2xf16> + return +} + +// CHECK-LABEL: @constant_64bit +func @constant_64bit() { + // CHECK: spv.constant 4 : i32 + %0 = constant 4 : i64 + // CHECK: spv.constant 5.000000e+00 : f32 + %1 = constant 5.0 : f64 + // CHECK: spv.constant dense<[2, 3]> : vector<2xi32> + %2 = constant dense<[2, 3]> : vector<2xi64> + // CHECK: spv.constant dense<4.000000e+00> : tensor<5xf32> : !spv.array<5 x f32 [4]> + %3 = constant dense<4.0> : tensor<5xf64> + // CHECK: spv.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00]> : tensor<4xf32> : !spv.array<4 x f32 [4]> + %4 = constant dense<[[1.0, 2.0], [3.0, 4.0]]> : tensor<2x2xf64> + return +} + +// CHECK-LABEL: @corner_cases +func @corner_cases() { + // CHECK: %{{.*}} = spv.constant -1 : i32 + %0 = constant 4294967295 : i64 // 2^32 - 1 + // CHECK: %{{.*}} = spv.constant 2147483647 : i32 + %1 = constant 2147483647 : i64 // 2^31 - 1 + // CHECK: %{{.*}} = spv.constant -2147483648 : i32 + %2 = constant 2147483648 : i64 // 2^31 + // CHECK: %{{.*}} = spv.constant -2147483648 : i32 + %3 = constant -2147483648 : i64 // -2^31 + + // CHECK: %{{.*}} = spv.constant -1 : i32 + %5 = constant -1 : i64 + // CHECK: %{{.*}} = spv.constant -2 : i32
+ %6 = constant -2 : i64 + // CHECK: %{{.*}} = spv.constant -1 : i32 + %7 = constant -1 : index + // CHECK: %{{.*}} = spv.constant -2 : i32 + %8 = constant -2 : index + + + // CHECK: spv.constant false + %9 = constant 0 : i1 + // CHECK: spv.constant true + %10 = constant 1 : i1 + + return +} + +// CHECK-LABEL: @unsupported_cases +func @unsupported_cases() { + // CHECK: %{{.*}} = constant 4294967296 : i64 + %0 = constant 4294967296 : i64 // 2^32 + // CHECK: %{{.*}} = constant -2147483649 : i64 + %1 = constant -2147483649 : i64 // -2^31 - 1 + // CHECK: %{{.*}} = constant 1.0000000000000002 : f64 + %2 = constant 0x3FF0000000000001 : f64 // smallest f64 > 1 + return +} + } // end module // -----