Changeset View
Changeset View
Standalone View
Standalone View
mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp
Show All 12 Lines | |||||
#include "mlir/Dialect/SPIRV/LayoutUtils.h" | #include "mlir/Dialect/SPIRV/LayoutUtils.h" | ||||
#include "mlir/Dialect/SPIRV/SPIRVDialect.h" | #include "mlir/Dialect/SPIRV/SPIRVDialect.h" | ||||
#include "mlir/Dialect/SPIRV/SPIRVLowering.h" | #include "mlir/Dialect/SPIRV/SPIRVLowering.h" | ||||
#include "mlir/Dialect/SPIRV/SPIRVOps.h" | #include "mlir/Dialect/SPIRV/SPIRVOps.h" | ||||
#include "mlir/Dialect/StandardOps/IR/Ops.h" | #include "mlir/Dialect/StandardOps/IR/Ops.h" | ||||
#include "mlir/IR/AffineMap.h" | #include "mlir/IR/AffineMap.h" | ||||
#include "mlir/Support/LogicalResult.h" | #include "mlir/Support/LogicalResult.h" | ||||
#include "llvm/ADT/SetVector.h" | #include "llvm/ADT/SetVector.h" | ||||
#include "llvm/Support/Debug.h" | |||||
#define DEBUG_TYPE "std-to-spirv-pattern" | |||||
using namespace mlir; | using namespace mlir; | ||||
//===----------------------------------------------------------------------===// | //===----------------------------------------------------------------------===// | ||||
// Utility functions | // Utility functions | ||||
//===----------------------------------------------------------------------===// | //===----------------------------------------------------------------------===// | ||||
/// Returns true if the given `type` is a boolean scalar or vector type. | /// Returns true if the given `type` is a boolean scalar or vector type. | ||||
static bool isBoolScalarOrVector(Type type) { | static bool isBoolScalarOrVector(Type type) { | ||||
if (type.isInteger(1)) | if (type.isInteger(1)) | ||||
return true; | return true; | ||||
if (auto vecType = type.dyn_cast<VectorType>()) | if (auto vecType = type.dyn_cast<VectorType>()) | ||||
return vecType.getElementType().isInteger(1); | return vecType.getElementType().isInteger(1); | ||||
return false; | return false; | ||||
} | } | ||||
/// Converts the given `srcAttr` into a boolean attribute if it holds a integral | |||||
/// value. Returns null attribute if conversion fails. | |||||
static BoolAttr convertBoolAttr(Attribute srcAttr, Builder builder) { | |||||
if (auto boolAttr = srcAttr.dyn_cast<BoolAttr>()) | |||||
return boolAttr; | |||||
if (auto intAttr = srcAttr.dyn_cast<IntegerAttr>()) | |||||
return builder.getBoolAttr(intAttr.getValue().getBoolValue()); | |||||
return BoolAttr(); | |||||
} | |||||
/// Converts the given `srcAttr` to a new attribute of the given `dstType`. | |||||
/// Returns null attribute if conversion fails. | |||||
static IntegerAttr convertIntegerAttr(IntegerAttr srcAttr, IntegerType dstType, | |||||
Builder builder) { | |||||
// If the source number uses less active bits than the target bitwidth, then | |||||
mravishankar: +1 for this is dangerous! | |||||
Sadly although we say index cannot be negative, but AFAICT, we are actually relying on negative index values in passes like https://github.com/llvm/llvm-project/blob/bd0ca2627cfa1acde2a272347ed55d88a9751869/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp#L126. That is the reason of this statement. antiagainst: Sadly although we say index cannot be negative, but AFAICT, we are actually relying on negative… | |||||
// it should be safe to convert. | |||||
if (srcAttr.getValue().isIntN(dstType.getWidth())) | |||||
return builder.getIntegerAttr(dstType, srcAttr.getInt()); | |||||
// XXX: Try again by interpreting the source number as a signed value. | |||||
// Although integers in the standard dialect are signless, they can represent | |||||
// a signed number. It's the operation decides how to interpret. This is | |||||
// dangerous, but it seems there is no good way of handling this if we still | |||||
// want to change the bitwidth. Emit a message at least. | |||||
if (srcAttr.getValue().isSignedIntN(dstType.getWidth())) { | |||||
auto dstAttr = builder.getIntegerAttr(dstType, srcAttr.getInt()); | |||||
LLVM_DEBUG(llvm::dbgs() << "attribute '" << srcAttr << "' converted to '" | |||||
<< dstAttr << "' for type '" << dstType << "'\n"); | |||||
return dstAttr; | |||||
} | |||||
LLVM_DEBUG(llvm::dbgs() << "attribute '" << srcAttr | |||||
<< "' illegal: cannot fit into target type '" | |||||
<< dstType << "'\n"); | |||||
return IntegerAttr(); | |||||
} | |||||
/// Converts the given `srcAttr` to a new attribute of the given `dstType`. | |||||
/// Returns null attribute if `dstType` is not 32-bit or conversion fails. | |||||
static FloatAttr convertFloatAttr(FloatAttr srcAttr, FloatType dstType, | |||||
Builder builder) { | |||||
// Only support converting to float for now. | |||||
if (!dstType.isF32()) | |||||
return FloatAttr(); | |||||
// Try to convert the source floating-point number to single precision. | |||||
APFloat dstVal = srcAttr.getValue(); | |||||
bool losesInfo = false; | |||||
APFloat::opStatus status = | |||||
dstVal.convert(APFloat::IEEEsingle(), APFloat::rmTowardZero, &losesInfo); | |||||
if (status != APFloat::opOK || losesInfo) { | |||||
LLVM_DEBUG(llvm::dbgs() | |||||
<< srcAttr << " illegal: cannot fit into converted type '" | |||||
<< dstType << "'\n"); | |||||
return FloatAttr(); | |||||
} | |||||
return builder.getF32FloatAttr(dstVal.convertToFloat()); | |||||
} | |||||
//===----------------------------------------------------------------------===// | //===----------------------------------------------------------------------===// | ||||
// Operation conversion | // Operation conversion | ||||
//===----------------------------------------------------------------------===// | //===----------------------------------------------------------------------===// | ||||
// Note that DRR cannot be used for the patterns in this file: we may need to | // Note that DRR cannot be used for the patterns in this file: we may need to | ||||
// convert type along the way, which requires ConversionPattern. DRR generates | // convert type along the way, which requires ConversionPattern. DRR generates | ||||
// normal RewritePattern. | // normal RewritePattern. | ||||
▲ Show 20 Lines • Show All 47 Lines • ▼ Show 20 Lines | |||||
}; | }; | ||||
/// Converts composite std.constant operation to spv.constant. | /// Converts composite std.constant operation to spv.constant. | ||||
class ConstantCompositeOpPattern final : public SPIRVOpLowering<ConstantOp> { | class ConstantCompositeOpPattern final : public SPIRVOpLowering<ConstantOp> { | ||||
public: | public: | ||||
using SPIRVOpLowering<ConstantOp>::SPIRVOpLowering; | using SPIRVOpLowering<ConstantOp>::SPIRVOpLowering; | ||||
LogicalResult | LogicalResult | ||||
matchAndRewrite(ConstantOp constCompositeOp, ArrayRef<Value> operands, | matchAndRewrite(ConstantOp constOp, ArrayRef<Value> operands, | ||||
ConversionPatternRewriter &rewriter) const override; | ConversionPatternRewriter &rewriter) const override; | ||||
}; | }; | ||||
/// Converts scalar std.constant operation to spv.constant. | /// Converts scalar std.constant operation to spv.constant. | ||||
class ConstantScalarOpPattern final : public SPIRVOpLowering<ConstantOp> { | class ConstantScalarOpPattern final : public SPIRVOpLowering<ConstantOp> { | ||||
public: | public: | ||||
using SPIRVOpLowering<ConstantOp>::SPIRVOpLowering; | using SPIRVOpLowering<ConstantOp>::SPIRVOpLowering; | ||||
LogicalResult | LogicalResult | ||||
matchAndRewrite(ConstantOp constIndexOp, ArrayRef<Value> operands, | matchAndRewrite(ConstantOp constOp, ArrayRef<Value> operands, | ||||
ConversionPatternRewriter &rewriter) const override; | ConversionPatternRewriter &rewriter) const override; | ||||
}; | }; | ||||
/// Converts floating-point comparison operations to SPIR-V ops. | /// Converts floating-point comparison operations to SPIR-V ops. | ||||
class CmpFOpPattern final : public SPIRVOpLowering<CmpFOp> { | class CmpFOpPattern final : public SPIRVOpLowering<CmpFOp> { | ||||
public: | public: | ||||
using SPIRVOpLowering<CmpFOp>::SPIRVOpLowering; | using SPIRVOpLowering<CmpFOp>::SPIRVOpLowering; | ||||
▲ Show 20 Lines • Show All 87 Lines • ▼ Show 20 Lines | |||||
} // namespace | } // namespace | ||||
//===----------------------------------------------------------------------===// | //===----------------------------------------------------------------------===// | ||||
// ConstantOp with composite type. | // ConstantOp with composite type. | ||||
//===----------------------------------------------------------------------===// | //===----------------------------------------------------------------------===// | ||||
LogicalResult ConstantCompositeOpPattern::matchAndRewrite( | LogicalResult ConstantCompositeOpPattern::matchAndRewrite( | ||||
ConstantOp constCompositeOp, ArrayRef<Value> operands, | ConstantOp constOp, ArrayRef<Value> operands, | ||||
ConversionPatternRewriter &rewriter) const { | ConversionPatternRewriter &rewriter) const { | ||||
auto compositeType = | auto srcType = constOp.getType().dyn_cast<ShapedType>(); | ||||
constCompositeOp.getResult().getType().dyn_cast<RankedTensorType>(); | if (!srcType) | ||||
if (!compositeType) | |||||
return failure(); | return failure(); | ||||
auto spirvCompositeType = typeConverter.convertType(compositeType); | // std.constant should only have vector or tenor types. | ||||
if (!spirvCompositeType) | assert(srcType.isa<VectorType>() || srcType.isa<RankedTensorType>()); | ||||
auto dstType = typeConverter.convertType(srcType); | |||||
if (!dstType) | |||||
return failure(); | return failure(); | ||||
auto linearizedElements = | auto dstElementsAttr = constOp.value().dyn_cast<DenseElementsAttr>(); | ||||
constCompositeOp.value().dyn_cast<DenseElementsAttr>(); | ShapedType dstAttrType = dstElementsAttr.getType(); | ||||
if (!linearizedElements) | if (!dstElementsAttr) | ||||
return failure(); | return failure(); | ||||
// If composite type has rank greater than one, then perform linearization. | // If the composite type has more than one dimensions, perform linearization. | ||||
if (compositeType.getRank() > 1) { | if (srcType.getRank() > 1) { | ||||
auto linearizedType = RankedTensorType::get(compositeType.getNumElements(), | if (srcType.isa<RankedTensorType>()) { | ||||
compositeType.getElementType()); | dstAttrType = RankedTensorType::get(srcType.getNumElements(), | ||||
linearizedElements = linearizedElements.reshape(linearizedType); | srcType.getElementType()); | ||||
dstElementsAttr = dstElementsAttr.reshape(dstAttrType); | |||||
} else { | |||||
// TODO(antiagainst): add support for large vectors. | |||||
return failure(); | |||||
} | |||||
} | |||||
Type srcElemType = srcType.getElementType(); | |||||
Type dstElemType; | |||||
// Tensor types are converted to SPIR-V array types; vector types are | |||||
// converted to SPIR-V vector/array types. | |||||
if (auto arrayType = dstType.dyn_cast<spirv::ArrayType>()) | |||||
dstElemType = arrayType.getElementType(); | |||||
else | |||||
dstElemType = dstType.cast<VectorType>().getElementType(); | |||||
// If the source and destination element types are different, perform | |||||
// attribute conversion. | |||||
if (srcElemType != dstElemType) { | |||||
SmallVector<Attribute, 8> elements; | |||||
if (srcElemType.isa<FloatType>()) { | |||||
for (Attribute srcAttr : dstElementsAttr.getAttributeValues()) { | |||||
FloatAttr dstAttr = convertFloatAttr( | |||||
srcAttr.cast<FloatAttr>(), dstElemType.cast<FloatType>(), rewriter); | |||||
if (!dstAttr) | |||||
return failure(); | |||||
elements.push_back(dstAttr); | |||||
} | |||||
} else if (srcElemType.isInteger(1)) { | |||||
return failure(); | |||||
} else { | |||||
for (Attribute srcAttr : dstElementsAttr.getAttributeValues()) { | |||||
IntegerAttr dstAttr = | |||||
convertIntegerAttr(srcAttr.cast<IntegerAttr>(), | |||||
dstElemType.cast<IntegerType>(), rewriter); | |||||
if (!dstAttr) | |||||
return failure(); | |||||
elements.push_back(dstAttr); | |||||
} | |||||
} | } | ||||
rewriter.replaceOpWithNewOp<spirv::ConstantOp>( | // Unfortunately, we cannot use dialect-specific types for element | ||||
constCompositeOp, spirvCompositeType, linearizedElements); | // attributes; element attributes only works with standard types. So we need | ||||
// to prepare another converted standard types for the destination elements | |||||
// attribute. | |||||
if (dstAttrType.isa<RankedTensorType>()) | |||||
dstAttrType = RankedTensorType::get(dstAttrType.getShape(), dstElemType); | |||||
else | |||||
dstAttrType = VectorType::get(dstAttrType.getShape(), dstElemType); | |||||
dstElementsAttr = DenseElementsAttr::get(dstAttrType, elements); | |||||
} | |||||
rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constOp, dstType, | |||||
dstElementsAttr); | |||||
return success(); | return success(); | ||||
} | } | ||||
//===----------------------------------------------------------------------===// | //===----------------------------------------------------------------------===// | ||||
// ConstantOp with scalar type. | // ConstantOp with scalar type. | ||||
//===----------------------------------------------------------------------===// | //===----------------------------------------------------------------------===// | ||||
LogicalResult ConstantScalarOpPattern::matchAndRewrite( | LogicalResult ConstantScalarOpPattern::matchAndRewrite( | ||||
ConstantOp constIndexOp, ArrayRef<Value> operands, | ConstantOp constOp, ArrayRef<Value> operands, | ||||
ConversionPatternRewriter &rewriter) const { | ConversionPatternRewriter &rewriter) const { | ||||
if (!constIndexOp.getResult().getType().isa<IndexType>()) { | Type srcType = constOp.getType(); | ||||
if (!srcType.isIntOrIndexOrFloat()) | |||||
return failure(); | return failure(); | ||||
} | |||||
// The attribute has index type which is not directly supported in | Type dstType = typeConverter.convertType(srcType); | ||||
// SPIR-V. Get the integer value and create a new IntegerAttr. | if (!dstType) | ||||
auto constAttr = constIndexOp.value().dyn_cast<IntegerAttr>(); | return failure(); | ||||
if (!constAttr) { | |||||
// Floating-point types. | |||||
if (srcType.isa<FloatType>()) { | |||||
auto srcAttr = constOp.value().cast<FloatAttr>(); | |||||
auto dstAttr = srcAttr; | |||||
// Floating-point types not supported in the target environment are all | |||||
// converted to float type. | |||||
if (srcType != dstType) { | |||||
dstAttr = convertFloatAttr(srcAttr, dstType.cast<FloatType>(), rewriter); | |||||
if (!dstAttr) | |||||
return failure(); | return failure(); | ||||
} | } | ||||
// Use the bitwidth set in the value attribute to decide the result type | rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constOp, dstType, dstAttr); | ||||
// of the SPIR-V constant operation since SPIR-V does not support index | return success(); | ||||
// types. | } | ||||
auto constVal = constAttr.getValue(); | |||||
auto constValType = constAttr.getType().dyn_cast<IndexType>(); | // Bool type. | ||||
if (!constValType) { | if (srcType.isInteger(1)) { | ||||
// std.constant can use 0/1 instead of true/false for i1 values. We need to | |||||
// handle that here. | |||||
auto dstAttr = convertBoolAttr(constOp.value(), rewriter); | |||||
if (!dstAttr) | |||||
return failure(); | return failure(); | ||||
rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constOp, dstType, dstAttr); | |||||
return success(); | |||||
} | } | ||||
auto spirvConstType = | |||||
typeConverter.convertType(constIndexOp.getResult().getType()); | // IndexType or IntegerType. Index values are converted to 32-bit integer | ||||
auto spirvConstVal = | // values when converting to SPIR-V. | ||||
rewriter.getIntegerAttr(spirvConstType, constAttr.getInt()); | auto srcAttr = constOp.value().cast<IntegerAttr>(); | ||||
rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constIndexOp, spirvConstType, | auto dstAttr = | ||||
spirvConstVal); | convertIntegerAttr(srcAttr, dstType.cast<IntegerType>(), rewriter); | ||||
if (!dstAttr) | |||||
return failure(); | |||||
rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constOp, dstType, dstAttr); | |||||
return success(); | return success(); | ||||
} | } | ||||
//===----------------------------------------------------------------------===// | //===----------------------------------------------------------------------===// | ||||
// CmpFOp | // CmpFOp | ||||
//===----------------------------------------------------------------------===// | //===----------------------------------------------------------------------===// | ||||
LogicalResult | LogicalResult | ||||
▲ Show 20 Lines • Show All 145 Lines • ▼ Show 20 Lines | XOrOpPattern::matchAndRewrite(XOrOp xorOp, ArrayRef<Value> operands, | ||||
return success(); | return success(); | ||||
} | } | ||||
//===----------------------------------------------------------------------===// | //===----------------------------------------------------------------------===// | ||||
// Pattern population | // Pattern population | ||||
//===----------------------------------------------------------------------===// | //===----------------------------------------------------------------------===// | ||||
namespace { | |||||
/// Import the Standard Ops to SPIR-V Patterns. | |||||
#include "StandardToSPIRV.cpp.inc" | |||||
} // namespace | |||||
namespace mlir { | namespace mlir { | ||||
void populateStandardToSPIRVPatterns(MLIRContext *context, | void populateStandardToSPIRVPatterns(MLIRContext *context, | ||||
SPIRVTypeConverter &typeConverter, | SPIRVTypeConverter &typeConverter, | ||||
OwningRewritePatternList &patterns) { | OwningRewritePatternList &patterns) { | ||||
// Add patterns that lower operations into SPIR-V dialect. | |||||
populateWithGenerated(context, &patterns); | |||||
patterns.insert< | patterns.insert< | ||||
BinaryOpPattern<AddFOp, spirv::FAddOp>, | BinaryOpPattern<AddFOp, spirv::FAddOp>, | ||||
BinaryOpPattern<AddIOp, spirv::IAddOp>, | BinaryOpPattern<AddIOp, spirv::IAddOp>, | ||||
BinaryOpPattern<DivFOp, spirv::FDivOp>, | BinaryOpPattern<DivFOp, spirv::FDivOp>, | ||||
BinaryOpPattern<MulFOp, spirv::FMulOp>, | BinaryOpPattern<MulFOp, spirv::FMulOp>, | ||||
BinaryOpPattern<MulIOp, spirv::IMulOp>, | BinaryOpPattern<MulIOp, spirv::IMulOp>, | ||||
BinaryOpPattern<RemFOp, spirv::FRemOp>, | BinaryOpPattern<RemFOp, spirv::FRemOp>, | ||||
BinaryOpPattern<ShiftLeftOp, spirv::ShiftLeftLogicalOp>, | BinaryOpPattern<ShiftLeftOp, spirv::ShiftLeftLogicalOp>, | ||||
Show All 18 Lines |
+1 for this is dangerous!