diff --git a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h --- a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h +++ b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h @@ -364,10 +364,11 @@ }; /// Base class for operation conversions targeting the LLVM IR dialect. Provides /// conversion patterns with access to an LLVMTypeConverter. -class LLVMOpLowering : public ConversionPattern { +class ConvertToLLVMPattern : public ConversionPattern { public: - LLVMOpLowering(StringRef rootOpName, MLIRContext *context, - LLVMTypeConverter &typeConverter, PatternBenefit benefit = 1); + ConvertToLLVMPattern(StringRef rootOpName, MLIRContext *context, + LLVMTypeConverter &typeConverter, + PatternBenefit benefit = 1); protected: /// Reference to the type converter, with potential extensions. diff --git a/mlir/lib/Conversion/GPUCommon/IndexIntrinsicsOpLowering.h b/mlir/lib/Conversion/GPUCommon/IndexIntrinsicsOpLowering.h --- a/mlir/lib/Conversion/GPUCommon/IndexIntrinsicsOpLowering.h +++ b/mlir/lib/Conversion/GPUCommon/IndexIntrinsicsOpLowering.h @@ -11,7 +11,6 @@ #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" #include "mlir/Dialect/GPU/GPUDialect.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" - #include "llvm/ADT/StringSwitch.h" namespace mlir { @@ -22,7 +21,7 @@ // `indexBitwidth`, sign-extend or truncate the resulting value to match the // bitwidth expected by the consumers of the value. template -struct GPUIndexIntrinsicOpLowering : public LLVMOpLowering { +struct GPUIndexIntrinsicOpLowering : public ConvertToLLVMPattern { private: enum dimension { X = 0, Y = 1, Z = 2, invalid }; unsigned indexBitwidth; @@ -42,8 +41,8 @@ public: explicit GPUIndexIntrinsicOpLowering(LLVMTypeConverter &lowering_) - : LLVMOpLowering(Op::getOperationName(), - lowering_.getDialect()->getContext(), lowering_), + : ConvertToLLVMPattern(Op::getOperationName(), + lowering_.getDialect()->getContext(), lowering_), indexBitwidth(getIndexBitWidth(lowering_)) {} // Convert the kernel arguments to an LLVM type, preserve the rest. diff --git a/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h b/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h --- a/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h +++ b/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h @@ -26,12 +26,12 @@ /// will be transformed into /// llvm.call @__nv_expf(%arg_f32) : (!llvm.float) -> !llvm.float template -struct OpToFuncCallLowering : public LLVMOpLowering { +struct OpToFuncCallLowering : public ConvertToLLVMPattern { public: explicit OpToFuncCallLowering(LLVMTypeConverter &lowering_, StringRef f32Func, StringRef f64Func) - : LLVMOpLowering(SourceOp::getOperationName(), - lowering_.getDialect()->getContext(), lowering_), + : ConvertToLLVMPattern(SourceOp::getOperationName(), + lowering_.getDialect()->getContext(), lowering_), f32Func(f32Func), f64Func(f64Func) {} PatternMatchResult diff --git a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp --- a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp +++ b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp @@ -19,7 +19,6 @@ #include "mlir/IR/BlockAndValueMapping.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/DialectConversion.h" - #include "llvm/Support/FormatVariadic.h" #include "../GPUCommon/IndexIntrinsicsOpLowering.h" @@ -49,13 +48,13 @@ }; /// Converts all_reduce op to LLVM/NVVM ops. -struct GPUAllReduceOpLowering : public LLVMOpLowering { +struct GPUAllReduceOpLowering : public ConvertToLLVMPattern { using AccumulatorFactory = std::function; explicit GPUAllReduceOpLowering(LLVMTypeConverter &lowering_) - : LLVMOpLowering(gpu::AllReduceOp::getOperationName(), - lowering_.getDialect()->getContext(), lowering_), + : ConvertToLLVMPattern(gpu::AllReduceOp::getOperationName(), + lowering_.getDialect()->getContext(), lowering_), int32Type(LLVM::LLVMType::getInt32Ty(lowering_.getDialect())) {} PatternMatchResult @@ -463,10 +462,10 @@ static constexpr int kWarpSize = 32; }; -struct GPUShuffleOpLowering : public LLVMOpLowering { +struct GPUShuffleOpLowering : public ConvertToLLVMPattern { explicit GPUShuffleOpLowering(LLVMTypeConverter &lowering_) - : LLVMOpLowering(gpu::ShuffleOp::getOperationName(), - lowering_.getDialect()->getContext(), lowering_) {} + : ConvertToLLVMPattern(gpu::ShuffleOp::getOperationName(), + lowering_.getDialect()->getContext(), lowering_) {} /// Lowers a shuffle to the corresponding NVVM op. /// @@ -521,11 +520,11 @@ } }; -struct GPUFuncOpLowering : LLVMOpLowering { +struct GPUFuncOpLowering : ConvertToLLVMPattern { explicit GPUFuncOpLowering(LLVMTypeConverter &typeConverter) - : LLVMOpLowering(gpu::GPUFuncOp::getOperationName(), - typeConverter.getDialect()->getContext(), - typeConverter) {} + : ConvertToLLVMPattern(gpu::GPUFuncOp::getOperationName(), + typeConverter.getDialect()->getContext(), + typeConverter) {} PatternMatchResult matchAndRewrite(Operation *op, ArrayRef operands, @@ -656,11 +655,11 @@ } }; -struct GPUReturnOpLowering : public LLVMOpLowering { +struct GPUReturnOpLowering : public ConvertToLLVMPattern { GPUReturnOpLowering(LLVMTypeConverter &typeConverter) - : LLVMOpLowering(gpu::ReturnOp::getOperationName(), - typeConverter.getDialect()->getContext(), - typeConverter) {} + : ConvertToLLVMPattern(gpu::ReturnOp::getOperationName(), + typeConverter.getDialect()->getContext(), + typeConverter) {} PatternMatchResult matchAndRewrite(Operation *op, ArrayRef operands, diff --git a/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp b/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp --- a/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp +++ b/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp @@ -7,6 +7,7 @@ //===----------------------------------------------------------------------===// #include "mlir/Conversion/LinalgToLLVM/LinalgToLLVM.h" + #include "mlir/Conversion/AffineToStandard/AffineToStandard.h" #include "mlir/Conversion/LoopToStandard/ConvertLoopToStandard.h" #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" @@ -32,7 +33,6 @@ #include "mlir/Support/LogicalResult.h" #include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/Passes.h" - #include "llvm/ADT/SetVector.h" #include "llvm/IR/DerivedTypes.h" #include "llvm/IR/Module.h" @@ -136,10 +136,10 @@ }; // RangeOp creates a new range descriptor. -class RangeOpConversion : public LLVMOpLowering { +class RangeOpConversion : public ConvertToLLVMPattern { public: explicit RangeOpConversion(MLIRContext *context, LLVMTypeConverter &lowering_) - : LLVMOpLowering(RangeOp::getOperationName(), context, lowering_) {} + : ConvertToLLVMPattern(RangeOp::getOperationName(), context, lowering_) {} PatternMatchResult matchAndRewrite(Operation *op, ArrayRef operands, @@ -164,11 +164,12 @@ // ReshapeOp creates a new view descriptor of the proper rank. // For now, the only conversion supported is for target MemRef with static sizes // and strides. -class ReshapeOpConversion : public LLVMOpLowering { +class ReshapeOpConversion : public ConvertToLLVMPattern { public: explicit ReshapeOpConversion(MLIRContext *context, LLVMTypeConverter &lowering_) - : LLVMOpLowering(ReshapeOp::getOperationName(), context, lowering_) {} + : ConvertToLLVMPattern(ReshapeOp::getOperationName(), context, + lowering_) {} PatternMatchResult matchAndRewrite(Operation *op, ArrayRef operands, @@ -211,10 +212,10 @@ /// the parent view. /// 4. A store of the resulting ViewDescriptor to the alloca'ed pointer. /// The linalg.slice op is replaced by the alloca'ed pointer. -class SliceOpConversion : public LLVMOpLowering { +class SliceOpConversion : public ConvertToLLVMPattern { public: explicit SliceOpConversion(MLIRContext *context, LLVMTypeConverter &lowering_) - : LLVMOpLowering(SliceOp::getOperationName(), context, lowering_) {} + : ConvertToLLVMPattern(SliceOp::getOperationName(), context, lowering_) {} PatternMatchResult matchAndRewrite(Operation *op, ArrayRef operands, @@ -302,11 +303,12 @@ /// and stride. Size and stride are permutations of the original values. /// 4. A store of the resulting ViewDescriptor to the alloca'ed pointer. /// The linalg.transpose op is replaced by the alloca'ed pointer. -class TransposeOpConversion : public LLVMOpLowering { +class TransposeOpConversion : public ConvertToLLVMPattern { public: explicit TransposeOpConversion(MLIRContext *context, LLVMTypeConverter &lowering_) - : LLVMOpLowering(TransposeOp::getOperationName(), context, lowering_) {} + : ConvertToLLVMPattern(TransposeOp::getOperationName(), context, + lowering_) {} PatternMatchResult matchAndRewrite(Operation *op, ArrayRef operands, @@ -346,10 +348,10 @@ }; // YieldOp produces and LLVM::ReturnOp. -class YieldOpConversion : public LLVMOpLowering { +class YieldOpConversion : public ConvertToLLVMPattern { public: explicit YieldOpConversion(MLIRContext *context, LLVMTypeConverter &lowering_) - : LLVMOpLowering(YieldOp::getOperationName(), context, lowering_) {} + : ConvertToLLVMPattern(YieldOp::getOperationName(), context, lowering_) {} PatternMatchResult matchAndRewrite(Operation *op, ArrayRef operands, diff --git a/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp --- a/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp +++ b/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp @@ -12,6 +12,7 @@ //===----------------------------------------------------------------------===// #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" + #include "mlir/ADT/TypeSwitch.h" #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" @@ -25,7 +26,6 @@ #include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/Passes.h" #include "mlir/Transforms/Utils.h" - #include "llvm/IR/DerivedTypes.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/Type.h" @@ -375,9 +375,10 @@ .Default([](Type) { return Type(); }); } -LLVMOpLowering::LLVMOpLowering(StringRef rootOpName, MLIRContext *context, - LLVMTypeConverter &typeConverter_, - PatternBenefit benefit) +ConvertToLLVMPattern::ConvertToLLVMPattern(StringRef rootOpName, + MLIRContext *context, + LLVMTypeConverter &typeConverter_, + PatternBenefit benefit) : ConversionPattern(rootOpName, benefit, context), typeConverter(typeConverter_) {} @@ -703,13 +704,13 @@ // provided as template argument. Carries a reference to the LLVM dialect in // case it is necessary for rewriters. template -class LLVMLegalizationPattern : public LLVMOpLowering { +class LLVMLegalizationPattern : public ConvertToLLVMPattern { public: // Construct a conversion pattern. explicit LLVMLegalizationPattern(LLVM::LLVMDialect &dialect_, LLVMTypeConverter &typeConverter_) - : LLVMOpLowering(SourceOp::getOperationName(), dialect_.getContext(), - typeConverter_), + : ConvertToLLVMPattern(SourceOp::getOperationName(), + dialect_.getContext(), typeConverter_), dialect(dialect_) {} // Get the LLVM IR dialect. diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp @@ -7,6 +7,7 @@ //===----------------------------------------------------------------------===// #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h" + #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" @@ -24,7 +25,6 @@ #include "mlir/Pass/PassManager.h" #include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/Passes.h" - #include "llvm/IR/DerivedTypes.h" #include "llvm/IR/Module.h" #include "llvm/IR/Type.h" @@ -126,12 +126,12 @@ namespace { -class VectorBroadcastOpConversion : public LLVMOpLowering { +class VectorBroadcastOpConversion : public ConvertToLLVMPattern { public: explicit VectorBroadcastOpConversion(MLIRContext *context, LLVMTypeConverter &typeConverter) - : LLVMOpLowering(vector::BroadcastOp::getOperationName(), context, - typeConverter) {} + : ConvertToLLVMPattern(vector::BroadcastOp::getOperationName(), context, + typeConverter) {} PatternMatchResult matchAndRewrite(Operation *op, ArrayRef operands, @@ -275,12 +275,12 @@ } }; -class VectorReductionOpConversion : public LLVMOpLowering { +class VectorReductionOpConversion : public ConvertToLLVMPattern { public: explicit VectorReductionOpConversion(MLIRContext *context, LLVMTypeConverter &typeConverter) - : LLVMOpLowering(vector::ReductionOp::getOperationName(), context, - typeConverter) {} + : ConvertToLLVMPattern(vector::ReductionOp::getOperationName(), context, + typeConverter) {} PatternMatchResult matchAndRewrite(Operation *op, ArrayRef operands, @@ -343,12 +343,12 @@ }; // TODO(ajcbik): merge Reduction and ReductionV2 -class VectorReductionV2OpConversion : public LLVMOpLowering { +class VectorReductionV2OpConversion : public ConvertToLLVMPattern { public: explicit VectorReductionV2OpConversion(MLIRContext *context, LLVMTypeConverter &typeConverter) - : LLVMOpLowering(vector::ReductionV2Op::getOperationName(), context, - typeConverter) {} + : ConvertToLLVMPattern(vector::ReductionV2Op::getOperationName(), context, + typeConverter) {} PatternMatchResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { @@ -369,12 +369,12 @@ } }; -class VectorShuffleOpConversion : public LLVMOpLowering { +class VectorShuffleOpConversion : public ConvertToLLVMPattern { public: explicit VectorShuffleOpConversion(MLIRContext *context, LLVMTypeConverter &typeConverter) - : LLVMOpLowering(vector::ShuffleOp::getOperationName(), context, - typeConverter) {} + : ConvertToLLVMPattern(vector::ShuffleOp::getOperationName(), context, + typeConverter) {} PatternMatchResult matchAndRewrite(Operation *op, ArrayRef operands, @@ -427,12 +427,12 @@ } }; -class VectorExtractElementOpConversion : public LLVMOpLowering { +class VectorExtractElementOpConversion : public ConvertToLLVMPattern { public: explicit VectorExtractElementOpConversion(MLIRContext *context, LLVMTypeConverter &typeConverter) - : LLVMOpLowering(vector::ExtractElementOp::getOperationName(), context, - typeConverter) {} + : ConvertToLLVMPattern(vector::ExtractElementOp::getOperationName(), + context, typeConverter) {} PatternMatchResult matchAndRewrite(Operation *op, ArrayRef operands, @@ -452,12 +452,12 @@ } }; -class VectorExtractOpConversion : public LLVMOpLowering { +class VectorExtractOpConversion : public ConvertToLLVMPattern { public: explicit VectorExtractOpConversion(MLIRContext *context, LLVMTypeConverter &typeConverter) - : LLVMOpLowering(vector::ExtractOp::getOperationName(), context, - typeConverter) {} + : ConvertToLLVMPattern(vector::ExtractOp::getOperationName(), context, + typeConverter) {} PatternMatchResult matchAndRewrite(Operation *op, ArrayRef operands, @@ -521,12 +521,12 @@ /// (!llvm<"<8 x float>">, !llvm<"<8 x float>">, !llvm<"<8 x float>">) /// -> !llvm<"<8 x float>"> /// ``` -class VectorFMAOp1DConversion : public LLVMOpLowering { +class VectorFMAOp1DConversion : public ConvertToLLVMPattern { public: explicit VectorFMAOp1DConversion(MLIRContext *context, LLVMTypeConverter &typeConverter) - : LLVMOpLowering(vector::FMAOp::getOperationName(), context, - typeConverter) {} + : ConvertToLLVMPattern(vector::FMAOp::getOperationName(), context, + typeConverter) {} PatternMatchResult matchAndRewrite(Operation *op, ArrayRef operands, @@ -542,12 +542,12 @@ } }; -class VectorInsertElementOpConversion : public LLVMOpLowering { +class VectorInsertElementOpConversion : public ConvertToLLVMPattern { public: explicit VectorInsertElementOpConversion(MLIRContext *context, LLVMTypeConverter &typeConverter) - : LLVMOpLowering(vector::InsertElementOp::getOperationName(), context, - typeConverter) {} + : ConvertToLLVMPattern(vector::InsertElementOp::getOperationName(), + context, typeConverter) {} PatternMatchResult matchAndRewrite(Operation *op, ArrayRef operands, @@ -567,12 +567,12 @@ } }; -class VectorInsertOpConversion : public LLVMOpLowering { +class VectorInsertOpConversion : public ConvertToLLVMPattern { public: explicit VectorInsertOpConversion(MLIRContext *context, LLVMTypeConverter &typeConverter) - : LLVMOpLowering(vector::InsertOp::getOperationName(), context, - typeConverter) {} + : ConvertToLLVMPattern(vector::InsertOp::getOperationName(), context, + typeConverter) {} PatternMatchResult matchAndRewrite(Operation *op, ArrayRef operands, @@ -815,12 +815,12 @@ } }; -class VectorOuterProductOpConversion : public LLVMOpLowering { +class VectorOuterProductOpConversion : public ConvertToLLVMPattern { public: explicit VectorOuterProductOpConversion(MLIRContext *context, LLVMTypeConverter &typeConverter) - : LLVMOpLowering(vector::OuterProductOp::getOperationName(), context, - typeConverter) {} + : ConvertToLLVMPattern(vector::OuterProductOp::getOperationName(), + context, typeConverter) {} PatternMatchResult matchAndRewrite(Operation *op, ArrayRef operands, @@ -867,12 +867,12 @@ } }; -class VectorTypeCastOpConversion : public LLVMOpLowering { +class VectorTypeCastOpConversion : public ConvertToLLVMPattern { public: explicit VectorTypeCastOpConversion(MLIRContext *context, LLVMTypeConverter &typeConverter) - : LLVMOpLowering(vector::TypeCastOp::getOperationName(), context, - typeConverter) {} + : ConvertToLLVMPattern(vector::TypeCastOp::getOperationName(), context, + typeConverter) {} PatternMatchResult matchAndRewrite(Operation *op, ArrayRef operands, @@ -955,12 +955,12 @@ } }; -class VectorPrintOpConversion : public LLVMOpLowering { +class VectorPrintOpConversion : public ConvertToLLVMPattern { public: explicit VectorPrintOpConversion(MLIRContext *context, LLVMTypeConverter &typeConverter) - : LLVMOpLowering(vector::PrintOp::getOperationName(), context, - typeConverter) {} + : ConvertToLLVMPattern(vector::PrintOp::getOperationName(), context, + typeConverter) {} // Proof-of-concept lowering implementation that relies on a small // runtime support library, which only needs to provide a few