Changeset View
Changeset View
Standalone View
Standalone View
mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp
Show First 20 Lines • Show All 87 Lines • ▼ Show 20 Lines | return rewriter.create<LLVM::ConstantOp>( | ||||
loc, dstType, | loc, dstType, | ||||
SplatElementsAttr::get(srcType.cast<ShapedType>(), | SplatElementsAttr::get(srcType.cast<ShapedType>(), | ||||
minusOneIntegerAttribute(srcType, rewriter))); | minusOneIntegerAttribute(srcType, rewriter))); | ||||
} | } | ||||
return rewriter.create<LLVM::ConstantOp>( | return rewriter.create<LLVM::ConstantOp>( | ||||
loc, dstType, minusOneIntegerAttribute(srcType, rewriter)); | loc, dstType, minusOneIntegerAttribute(srcType, rewriter)); | ||||
} | } | ||||
/// Creates `llvm.mlir.constant` with a floating-point value. | |||||
antiagainst: a floating point scalar or vector value | |||||
static Value createFPConstant(Location loc, Type srcType, Type dstType, | |||||
PatternRewriter &rewriter, double value) { | |||||
if (auto vecType = srcType.dyn_cast<VectorType>()) { | |||||
auto floatType = vecType.getElementType().cast<FloatType>(); | |||||
return rewriter.create<LLVM::ConstantOp>( | |||||
loc, dstType, | |||||
SplatElementsAttr::get(vecType, | |||||
rewriter.getFloatAttr(floatType, value))); | |||||
} | |||||
auto floatType = srcType.cast<FloatType>(); | |||||
return rewriter.create<LLVM::ConstantOp>( | |||||
loc, dstType, rewriter.getFloatAttr(floatType, value)); | |||||
} | |||||
/// Utility function for bitfiled ops: | /// Utility function for bitfiled ops: | ||||
/// - `BitFieldInsert` | /// - `BitFieldInsert` | ||||
/// - `BitFieldSExtract` | /// - `BitFieldSExtract` | ||||
/// - `BitFieldUExtract` | /// - `BitFieldUExtract` | ||||
/// Truncates or extends the value. If the bitwidth of the value is the same as | /// Truncates or extends the value. If the bitwidth of the value is the same as | ||||
/// `dstType` bitwidth, the value remains unchanged. | /// `dstType` bitwidth, the value remains unchanged. | ||||
static Value optionallyTruncateOrExtend(Location loc, Value value, Type dstType, | static Value optionallyTruncateOrExtend(Location loc, Value value, Type dstType, | ||||
PatternRewriter &rewriter) { | PatternRewriter &rewriter) { | ||||
▲ Show 20 Lines • Show All 476 Lines • ▼ Show 20 Lines | matchAndRewrite(SPIRVOp operation, ArrayRef<Value> operands, | ||||
rewriter.template replaceOpWithNewOp<LLVM::ICmpOp>( | rewriter.template replaceOpWithNewOp<LLVM::ICmpOp>( | ||||
operation, dstType, | operation, dstType, | ||||
rewriter.getI64IntegerAttr(static_cast<int64_t>(predicate)), | rewriter.getI64IntegerAttr(static_cast<int64_t>(predicate)), | ||||
operation.operand1(), operation.operand2()); | operation.operand1(), operation.operand2()); | ||||
return success(); | return success(); | ||||
} | } | ||||
}; | }; | ||||
class InverseSqrtPattern | |||||
: public SPIRVToLLVMConversion<spirv::GLSLInverseSqrtOp> { | |||||
public: | |||||
using SPIRVToLLVMConversion<spirv::GLSLInverseSqrtOp>::SPIRVToLLVMConversion; | |||||
LogicalResult | |||||
matchAndRewrite(spirv::GLSLInverseSqrtOp op, ArrayRef<Value> operands, | |||||
ConversionPatternRewriter &rewriter) const override { | |||||
auto srcType = op.getType(); | |||||
auto dstType = typeConverter.convertType(srcType); | |||||
if (!dstType) | |||||
return failure(); | |||||
Location loc = op.getLoc(); | |||||
Value one = createFPConstant(loc, srcType, dstType, rewriter, 1.0); | |||||
Value sqrt = rewriter.create<LLVM::SqrtOp>(loc, dstType, op.operand()); | |||||
rewriter.replaceOpWithNewOp<LLVM::FDivOp>(op, dstType, one, sqrt); | |||||
return success(); | |||||
} | |||||
}; | |||||
/// Converts `spv.Load` and `spv.Store` to LLVM dialect. | /// Converts `spv.Load` and `spv.Store` to LLVM dialect. | ||||
template <typename SPIRVop> | template <typename SPIRVop> | ||||
class LoadStorePattern : public SPIRVToLLVMConversion<SPIRVop> { | class LoadStorePattern : public SPIRVToLLVMConversion<SPIRVop> { | ||||
public: | public: | ||||
using SPIRVToLLVMConversion<SPIRVop>::SPIRVToLLVMConversion; | using SPIRVToLLVMConversion<SPIRVop>::SPIRVToLLVMConversion; | ||||
LogicalResult | LogicalResult | ||||
matchAndRewrite(SPIRVop op, ArrayRef<Value> operands, | matchAndRewrite(SPIRVop op, ArrayRef<Value> operands, | ||||
▲ Show 20 Lines • Show All 203 Lines • ▼ Show 20 Lines | matchAndRewrite(spirv::GLSLTanOp tanOp, ArrayRef<Value> operands, | ||||
Location loc = tanOp.getLoc(); | Location loc = tanOp.getLoc(); | ||||
Value sin = rewriter.create<LLVM::SinOp>(loc, dstType, tanOp.operand()); | Value sin = rewriter.create<LLVM::SinOp>(loc, dstType, tanOp.operand()); | ||||
Value cos = rewriter.create<LLVM::CosOp>(loc, dstType, tanOp.operand()); | Value cos = rewriter.create<LLVM::CosOp>(loc, dstType, tanOp.operand()); | ||||
rewriter.replaceOpWithNewOp<LLVM::FDivOp>(tanOp, dstType, sin, cos); | rewriter.replaceOpWithNewOp<LLVM::FDivOp>(tanOp, dstType, sin, cos); | ||||
return success(); | return success(); | ||||
} | } | ||||
}; | }; | ||||
/// Convert `spv.Tanh` to | |||||
/// | |||||
/// exp(2x) - 1 | |||||
/// ----------- | |||||
/// exp(2x) + 1 | |||||
/// | |||||
class TanhPattern : public SPIRVToLLVMConversion<spirv::GLSLTanhOp> { | |||||
public: | |||||
using SPIRVToLLVMConversion<spirv::GLSLTanhOp>::SPIRVToLLVMConversion; | |||||
LogicalResult | |||||
matchAndRewrite(spirv::GLSLTanhOp tanhOp, ArrayRef<Value> operands, | |||||
ConversionPatternRewriter &rewriter) const override { | |||||
auto srcType = tanhOp.getType(); | |||||
auto dstType = typeConverter.convertType(srcType); | |||||
if (!dstType) | |||||
return failure(); | |||||
Location loc = tanhOp.getLoc(); | |||||
Value two = createFPConstant(loc, srcType, dstType, rewriter, 2.0); | |||||
Value multiplied = | |||||
rewriter.create<LLVM::FMulOp>(loc, dstType, two, tanhOp.operand()); | |||||
Value exponential = rewriter.create<LLVM::ExpOp>(loc, dstType, multiplied); | |||||
Value one = createFPConstant(loc, srcType, dstType, rewriter, 1.0); | |||||
Value numerator = | |||||
rewriter.create<LLVM::FSubOp>(loc, dstType, exponential, one); | |||||
Value denominator = | |||||
rewriter.create<LLVM::FAddOp>(loc, dstType, exponential, one); | |||||
rewriter.replaceOpWithNewOp<LLVM::FDivOp>(tanhOp, dstType, numerator, | |||||
denominator); | |||||
return success(); | |||||
} | |||||
}; | |||||
class VariablePattern : public SPIRVToLLVMConversion<spirv::VariableOp> { | class VariablePattern : public SPIRVToLLVMConversion<spirv::VariableOp> { | ||||
public: | public: | ||||
using SPIRVToLLVMConversion<spirv::VariableOp>::SPIRVToLLVMConversion; | using SPIRVToLLVMConversion<spirv::VariableOp>::SPIRVToLLVMConversion; | ||||
LogicalResult | LogicalResult | ||||
matchAndRewrite(spirv::VariableOp varOp, ArrayRef<Value> operands, | matchAndRewrite(spirv::VariableOp varOp, ArrayRef<Value> operands, | ||||
ConversionPatternRewriter &rewriter) const override { | ConversionPatternRewriter &rewriter) const override { | ||||
auto srcType = varOp.getType(); | auto srcType = varOp.getType(); | ||||
▲ Show 20 Lines • Show All 215 Lines • ▼ Show 20 Lines | patterns.insert< | ||||
// GLSL extended instruction set ops | // GLSL extended instruction set ops | ||||
DirectConversionPattern<spirv::GLSLCeilOp, LLVM::FCeilOp>, | DirectConversionPattern<spirv::GLSLCeilOp, LLVM::FCeilOp>, | ||||
DirectConversionPattern<spirv::GLSLCosOp, LLVM::CosOp>, | DirectConversionPattern<spirv::GLSLCosOp, LLVM::CosOp>, | ||||
DirectConversionPattern<spirv::GLSLExpOp, LLVM::ExpOp>, | DirectConversionPattern<spirv::GLSLExpOp, LLVM::ExpOp>, | ||||
DirectConversionPattern<spirv::GLSLFAbsOp, LLVM::FAbsOp>, | DirectConversionPattern<spirv::GLSLFAbsOp, LLVM::FAbsOp>, | ||||
DirectConversionPattern<spirv::GLSLLogOp, LLVM::LogOp>, | DirectConversionPattern<spirv::GLSLLogOp, LLVM::LogOp>, | ||||
DirectConversionPattern<spirv::GLSLSinOp, LLVM::SinOp>, | DirectConversionPattern<spirv::GLSLSinOp, LLVM::SinOp>, | ||||
DirectConversionPattern<spirv::GLSLSqrtOp, LLVM::SqrtOp>, TanPattern, | DirectConversionPattern<spirv::GLSLSqrtOp, LLVM::SqrtOp>, | ||||
InverseSqrtPattern, TanPattern, TanhPattern, | |||||
// Logical ops | // Logical ops | ||||
DirectConversionPattern<spirv::LogicalAndOp, LLVM::AndOp>, | DirectConversionPattern<spirv::LogicalAndOp, LLVM::AndOp>, | ||||
DirectConversionPattern<spirv::LogicalOrOp, LLVM::OrOp>, | DirectConversionPattern<spirv::LogicalOrOp, LLVM::OrOp>, | ||||
IComparePattern<spirv::LogicalEqualOp, LLVM::ICmpPredicate::eq>, | IComparePattern<spirv::LogicalEqualOp, LLVM::ICmpPredicate::eq>, | ||||
IComparePattern<spirv::LogicalNotEqualOp, LLVM::ICmpPredicate::ne>, | IComparePattern<spirv::LogicalNotEqualOp, LLVM::ICmpPredicate::ne>, | ||||
NotPattern<spirv::LogicalNotOp>, | NotPattern<spirv::LogicalNotOp>, | ||||
Show All 29 Lines |
a floating point scalar or vector value