diff --git a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h --- a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h +++ b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h @@ -364,10 +364,11 @@ }; /// Base class for operation conversions targeting the LLVM IR dialect. Provides /// conversion patterns with access to an LLVMTypeConverter. -class LLVMOpLowering : public ConversionPattern { +class ConvertToLLVMPattern : public ConversionPattern { public: - LLVMOpLowering(StringRef rootOpName, MLIRContext *context, - LLVMTypeConverter &typeConverter, PatternBenefit benefit = 1); + ConvertToLLVMPattern(StringRef rootOpName, MLIRContext *context, + LLVMTypeConverter &typeConverter, + PatternBenefit benefit = 1); protected: /// Reference to the type converter, with potential extensions. diff --git a/mlir/lib/Conversion/GPUCommon/IndexIntrinsicsOpLowering.h b/mlir/lib/Conversion/GPUCommon/IndexIntrinsicsOpLowering.h --- a/mlir/lib/Conversion/GPUCommon/IndexIntrinsicsOpLowering.h +++ b/mlir/lib/Conversion/GPUCommon/IndexIntrinsicsOpLowering.h @@ -22,7 +22,7 @@ // `indexBitwidth`, sign-extend or truncate the resulting value to match the // bitwidth expected by the consumers of the value. 
template -struct GPUIndexIntrinsicOpLowering : public LLVMOpLowering { +struct GPUIndexIntrinsicOpLowering : public ConvertToLLVMPattern { private: enum dimension { X = 0, Y = 1, Z = 2, invalid }; unsigned indexBitwidth; @@ -42,8 +42,8 @@ public: explicit GPUIndexIntrinsicOpLowering(LLVMTypeConverter &lowering_) - : LLVMOpLowering(Op::getOperationName(), - lowering_.getDialect()->getContext(), lowering_), + : ConvertToLLVMPattern(Op::getOperationName(), + lowering_.getDialect()->getContext(), lowering_), indexBitwidth(getIndexBitWidth(lowering_)) {} // Convert the kernel arguments to an LLVM type, preserve the rest. diff --git a/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h b/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h --- a/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h +++ b/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h @@ -26,12 +26,12 @@ /// will be transformed into /// llvm.call @__nv_expf(%arg_f32) : (!llvm.float) -> !llvm.float template -struct OpToFuncCallLowering : public LLVMOpLowering { +struct OpToFuncCallLowering : public ConvertToLLVMPattern { public: explicit OpToFuncCallLowering(LLVMTypeConverter &lowering_, StringRef f32Func, StringRef f64Func) - : LLVMOpLowering(SourceOp::getOperationName(), - lowering_.getDialect()->getContext(), lowering_), + : ConvertToLLVMPattern(SourceOp::getOperationName(), + lowering_.getDialect()->getContext(), lowering_), f32Func(f32Func), f64Func(f64Func) {} PatternMatchResult diff --git a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp --- a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp +++ b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp @@ -11,20 +11,18
@@ // //===----------------------------------------------------------------------===// +#include "../GPUCommon/IndexIntrinsicsOpLowering.h" +#include "../GPUCommon/OpToFuncCallLowering.h" #include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h" - #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h" #include "mlir/Dialect/GPU/GPUDialect.h" #include "mlir/Dialect/LLVMIR/NVVMDialect.h" #include "mlir/IR/BlockAndValueMapping.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/DialectConversion.h" #include "llvm/Support/FormatVariadic.h" -#include "../GPUCommon/IndexIntrinsicsOpLowering.h" -#include "../GPUCommon/OpToFuncCallLowering.h" - using namespace mlir; namespace { @@ -49,13 +47,13 @@ }; /// Converts all_reduce op to LLVM/NVVM ops. -struct GPUAllReduceOpLowering : public LLVMOpLowering { +struct GPUAllReduceOpLowering : public ConvertToLLVMPattern { using AccumulatorFactory = std::function; explicit GPUAllReduceOpLowering(LLVMTypeConverter &lowering_) - : LLVMOpLowering(gpu::AllReduceOp::getOperationName(), - lowering_.getDialect()->getContext(), lowering_), + : ConvertToLLVMPattern(gpu::AllReduceOp::getOperationName(), + lowering_.getDialect()->getContext(), lowering_), int32Type(LLVM::LLVMType::getInt32Ty(lowering_.getDialect())) {} PatternMatchResult @@ -463,10 +461,10 @@ static constexpr int kWarpSize = 32; }; -struct GPUShuffleOpLowering : public LLVMOpLowering { +struct GPUShuffleOpLowering : public ConvertToLLVMPattern { explicit GPUShuffleOpLowering(LLVMTypeConverter &lowering_) - : LLVMOpLowering(gpu::ShuffleOp::getOperationName(), - lowering_.getDialect()->getContext(), lowering_) {} + : ConvertToLLVMPattern(gpu::ShuffleOp::getOperationName(), + lowering_.getDialect()->getContext(), lowering_) {} /// Lowers a shuffle to the corresponding NVVM op. 
/// @@ -521,11 +519,11 @@ } }; -struct GPUFuncOpLowering : LLVMOpLowering { +struct GPUFuncOpLowering : ConvertToLLVMPattern { explicit GPUFuncOpLowering(LLVMTypeConverter &typeConverter) - : LLVMOpLowering(gpu::GPUFuncOp::getOperationName(), - typeConverter.getDialect()->getContext(), - typeConverter) {} + : ConvertToLLVMPattern(gpu::GPUFuncOp::getOperationName(), + typeConverter.getDialect()->getContext(), + typeConverter) {} PatternMatchResult matchAndRewrite(Operation *op, ArrayRef operands, @@ -656,11 +654,11 @@ } }; -struct GPUReturnOpLowering : public LLVMOpLowering { +struct GPUReturnOpLowering : public ConvertToLLVMPattern { GPUReturnOpLowering(LLVMTypeConverter &typeConverter) - : LLVMOpLowering(gpu::ReturnOp::getOperationName(), - typeConverter.getDialect()->getContext(), - typeConverter) {} + : ConvertToLLVMPattern(gpu::ReturnOp::getOperationName(), + typeConverter.getDialect()->getContext(), + typeConverter) {} PatternMatchResult matchAndRewrite(Operation *op, ArrayRef operands, diff --git a/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp b/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp --- a/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp +++ b/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp @@ -7,6 +7,7 @@ //===----------------------------------------------------------------------===// #include "mlir/Conversion/LinalgToLLVM/LinalgToLLVM.h" + #include "mlir/Conversion/AffineToStandard/AffineToStandard.h" #include "mlir/Conversion/LoopToStandard/ConvertLoopToStandard.h" #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" @@ -136,10 +137,10 @@ }; // RangeOp creates a new range descriptor. 
-class RangeOpConversion : public LLVMOpLowering { +class RangeOpConversion : public ConvertToLLVMPattern { public: explicit RangeOpConversion(MLIRContext *context, LLVMTypeConverter &lowering_) - : LLVMOpLowering(RangeOp::getOperationName(), context, lowering_) {} + : ConvertToLLVMPattern(RangeOp::getOperationName(), context, lowering_) {} PatternMatchResult matchAndRewrite(Operation *op, ArrayRef operands, @@ -164,11 +165,12 @@ // ReshapeOp creates a new view descriptor of the proper rank. // For now, the only conversion supported is for target MemRef with static sizes // and strides. -class ReshapeOpConversion : public LLVMOpLowering { +class ReshapeOpConversion : public ConvertToLLVMPattern { public: explicit ReshapeOpConversion(MLIRContext *context, LLVMTypeConverter &lowering_) - : LLVMOpLowering(ReshapeOp::getOperationName(), context, lowering_) {} + : ConvertToLLVMPattern(ReshapeOp::getOperationName(), context, + lowering_) {} PatternMatchResult matchAndRewrite(Operation *op, ArrayRef operands, @@ -211,10 +213,10 @@ /// the parent view. /// 4. A store of the resulting ViewDescriptor to the alloca'ed pointer. /// The linalg.slice op is replaced by the alloca'ed pointer. -class SliceOpConversion : public LLVMOpLowering { +class SliceOpConversion : public ConvertToLLVMPattern { public: explicit SliceOpConversion(MLIRContext *context, LLVMTypeConverter &lowering_) - : LLVMOpLowering(SliceOp::getOperationName(), context, lowering_) {} + : ConvertToLLVMPattern(SliceOp::getOperationName(), context, lowering_) {} PatternMatchResult matchAndRewrite(Operation *op, ArrayRef operands, @@ -302,11 +304,12 @@ /// and stride. Size and stride are permutations of the original values. /// 4. A store of the resulting ViewDescriptor to the alloca'ed pointer. /// The linalg.transpose op is replaced by the alloca'ed pointer. 
-class TransposeOpConversion : public LLVMOpLowering { +class TransposeOpConversion : public ConvertToLLVMPattern { public: explicit TransposeOpConversion(MLIRContext *context, LLVMTypeConverter &lowering_) - : LLVMOpLowering(TransposeOp::getOperationName(), context, lowering_) {} + : ConvertToLLVMPattern(TransposeOp::getOperationName(), context, + lowering_) {} PatternMatchResult matchAndRewrite(Operation *op, ArrayRef operands, @@ -346,10 +349,10 @@ }; // YieldOp produces and LLVM::ReturnOp. -class YieldOpConversion : public LLVMOpLowering { +class YieldOpConversion : public ConvertToLLVMPattern { public: explicit YieldOpConversion(MLIRContext *context, LLVMTypeConverter &lowering_) - : LLVMOpLowering(YieldOp::getOperationName(), context, lowering_) {} + : ConvertToLLVMPattern(YieldOp::getOperationName(), context, lowering_) {} PatternMatchResult matchAndRewrite(Operation *op, ArrayRef operands, diff --git a/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp --- a/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp +++ b/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp @@ -12,6 +12,7 @@ //===----------------------------------------------------------------------===// #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" + #include "mlir/ADT/TypeSwitch.h" #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" @@ -375,9 +376,10 @@ .Default([](Type) { return Type(); }); } -LLVMOpLowering::LLVMOpLowering(StringRef rootOpName, MLIRContext *context, - LLVMTypeConverter 
&typeConverter_, - PatternBenefit benefit) +ConvertToLLVMPattern::ConvertToLLVMPattern(StringRef rootOpName, + MLIRContext *context, + LLVMTypeConverter &typeConverter_, + PatternBenefit benefit) : ConversionPattern(rootOpName, benefit, context), typeConverter(typeConverter_) {} @@ -703,13 +705,13 @@ // provided as template argument. Carries a reference to the LLVM dialect in // case it is necessary for rewriters. template -class LLVMLegalizationPattern : public LLVMOpLowering { +class LLVMLegalizationPattern : public ConvertToLLVMPattern { public: // Construct a conversion pattern. explicit LLVMLegalizationPattern(LLVM::LLVMDialect &dialect_, LLVMTypeConverter &typeConverter_) - : LLVMOpLowering(SourceOp::getOperationName(), dialect_.getContext(), - typeConverter_), + : ConvertToLLVMPattern(SourceOp::getOperationName(), + dialect_.getContext(), typeConverter_), dialect(dialect_) {} // Get the LLVM IR dialect. diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp @@ -7,6 +7,7 @@ //===----------------------------------------------------------------------===// #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h" + #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" @@ -126,12 +127,12 @@ namespace { -class VectorBroadcastOpConversion : public 
ConvertToLLVMPattern { public: explicit VectorBroadcastOpConversion(MLIRContext *context, LLVMTypeConverter &typeConverter) - : LLVMOpLowering(vector::BroadcastOp::getOperationName(), context, - typeConverter) {} + : ConvertToLLVMPattern(vector::BroadcastOp::getOperationName(), context, + typeConverter) {} PatternMatchResult matchAndRewrite(Operation *op, ArrayRef operands, @@ -275,12 +276,12 @@ } }; -class VectorReductionOpConversion : public LLVMOpLowering { +class VectorReductionOpConversion : public ConvertToLLVMPattern { public: explicit VectorReductionOpConversion(MLIRContext *context, LLVMTypeConverter &typeConverter) - : LLVMOpLowering(vector::ReductionOp::getOperationName(), context, - typeConverter) {} + : ConvertToLLVMPattern(vector::ReductionOp::getOperationName(), context, + typeConverter) {} PatternMatchResult matchAndRewrite(Operation *op, ArrayRef operands, @@ -343,12 +344,12 @@ }; // TODO(ajcbik): merge Reduction and ReductionV2 -class VectorReductionV2OpConversion : public LLVMOpLowering { +class VectorReductionV2OpConversion : public ConvertToLLVMPattern { public: explicit VectorReductionV2OpConversion(MLIRContext *context, LLVMTypeConverter &typeConverter) - : LLVMOpLowering(vector::ReductionV2Op::getOperationName(), context, - typeConverter) {} + : ConvertToLLVMPattern(vector::ReductionV2Op::getOperationName(), context, + typeConverter) {} PatternMatchResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { @@ -369,12 +370,12 @@ } }; -class VectorShuffleOpConversion : public LLVMOpLowering { +class VectorShuffleOpConversion : public ConvertToLLVMPattern { public: explicit VectorShuffleOpConversion(MLIRContext *context, LLVMTypeConverter &typeConverter) - : LLVMOpLowering(vector::ShuffleOp::getOperationName(), context, - typeConverter) {} + : ConvertToLLVMPattern(vector::ShuffleOp::getOperationName(), context, + typeConverter) {} PatternMatchResult matchAndRewrite(Operation *op, ArrayRef 
operands, @@ -427,12 +428,12 @@ } }; -class VectorExtractElementOpConversion : public LLVMOpLowering { +class VectorExtractElementOpConversion : public ConvertToLLVMPattern { public: explicit VectorExtractElementOpConversion(MLIRContext *context, LLVMTypeConverter &typeConverter) - : LLVMOpLowering(vector::ExtractElementOp::getOperationName(), context, - typeConverter) {} + : ConvertToLLVMPattern(vector::ExtractElementOp::getOperationName(), + context, typeConverter) {} PatternMatchResult matchAndRewrite(Operation *op, ArrayRef operands, @@ -452,12 +453,12 @@ } }; -class VectorExtractOpConversion : public LLVMOpLowering { +class VectorExtractOpConversion : public ConvertToLLVMPattern { public: explicit VectorExtractOpConversion(MLIRContext *context, LLVMTypeConverter &typeConverter) - : LLVMOpLowering(vector::ExtractOp::getOperationName(), context, - typeConverter) {} + : ConvertToLLVMPattern(vector::ExtractOp::getOperationName(), context, + typeConverter) {} PatternMatchResult matchAndRewrite(Operation *op, ArrayRef operands, @@ -521,12 +522,12 @@ /// (!llvm<"<8 x float>">, !llvm<"<8 x float>">, !llvm<"<8 x float>">) /// -> !llvm<"<8 x float>"> /// ``` -class VectorFMAOp1DConversion : public LLVMOpLowering { +class VectorFMAOp1DConversion : public ConvertToLLVMPattern { public: explicit VectorFMAOp1DConversion(MLIRContext *context, LLVMTypeConverter &typeConverter) - : LLVMOpLowering(vector::FMAOp::getOperationName(), context, - typeConverter) {} + : ConvertToLLVMPattern(vector::FMAOp::getOperationName(), context, + typeConverter) {} PatternMatchResult matchAndRewrite(Operation *op, ArrayRef operands, @@ -542,12 +543,12 @@ } }; -class VectorInsertElementOpConversion : public LLVMOpLowering { +class VectorInsertElementOpConversion : public ConvertToLLVMPattern { public: explicit VectorInsertElementOpConversion(MLIRContext *context, LLVMTypeConverter &typeConverter) - : LLVMOpLowering(vector::InsertElementOp::getOperationName(), context, - typeConverter) {} + : 
ConvertToLLVMPattern(vector::InsertElementOp::getOperationName(), + context, typeConverter) {} PatternMatchResult matchAndRewrite(Operation *op, ArrayRef operands, @@ -567,12 +568,12 @@ } }; -class VectorInsertOpConversion : public LLVMOpLowering { +class VectorInsertOpConversion : public ConvertToLLVMPattern { public: explicit VectorInsertOpConversion(MLIRContext *context, LLVMTypeConverter &typeConverter) - : LLVMOpLowering(vector::InsertOp::getOperationName(), context, - typeConverter) {} + : ConvertToLLVMPattern(vector::InsertOp::getOperationName(), context, + typeConverter) {} PatternMatchResult matchAndRewrite(Operation *op, ArrayRef operands, @@ -815,12 +816,12 @@ } }; -class VectorOuterProductOpConversion : public LLVMOpLowering { +class VectorOuterProductOpConversion : public ConvertToLLVMPattern { public: explicit VectorOuterProductOpConversion(MLIRContext *context, LLVMTypeConverter &typeConverter) - : LLVMOpLowering(vector::OuterProductOp::getOperationName(), context, - typeConverter) {} + : ConvertToLLVMPattern(vector::OuterProductOp::getOperationName(), + context, typeConverter) {} PatternMatchResult matchAndRewrite(Operation *op, ArrayRef operands, @@ -867,12 +868,12 @@ } }; -class VectorTypeCastOpConversion : public LLVMOpLowering { +class VectorTypeCastOpConversion : public ConvertToLLVMPattern { public: explicit VectorTypeCastOpConversion(MLIRContext *context, LLVMTypeConverter &typeConverter) - : LLVMOpLowering(vector::TypeCastOp::getOperationName(), context, - typeConverter) {} + : ConvertToLLVMPattern(vector::TypeCastOp::getOperationName(), context, + typeConverter) {} PatternMatchResult matchAndRewrite(Operation *op, ArrayRef operands, @@ -955,12 +956,12 @@ } }; -class VectorPrintOpConversion : public LLVMOpLowering { +class VectorPrintOpConversion : public ConvertToLLVMPattern { public: explicit VectorPrintOpConversion(MLIRContext *context, LLVMTypeConverter &typeConverter) - : LLVMOpLowering(vector::PrintOp::getOperationName(), context, - 
typeConverter) {} + : ConvertToLLVMPattern(vector::PrintOp::getOperationName(), context, + typeConverter) {} // Proof-of-concept lowering implementation that relies on a small // runtime support library, which only needs to provide a few