diff --git a/mlir/include/mlir/Conversion/LinalgToLLVM/LinalgToLLVM.h b/mlir/include/mlir/Conversion/LinalgToLLVM/LinalgToLLVM.h --- a/mlir/include/mlir/Conversion/LinalgToLLVM/LinalgToLLVM.h +++ b/mlir/include/mlir/Conversion/LinalgToLLVM/LinalgToLLVM.h @@ -18,8 +18,7 @@ /// Populate the given list with patterns that convert from Linalg to LLVM. void populateLinalgToLLVMConversionPatterns(LLVMTypeConverter &converter, - OwningRewritePatternList &patterns, - MLIRContext *ctx); + OwningRewritePatternList &patterns); /// Create a pass to convert Linalg operations to the LLVMIR dialect. std::unique_ptr> createConvertLinalgToLLVMPass(); diff --git a/mlir/include/mlir/Conversion/OpenMPToLLVM/ConvertOpenMPToLLVM.h b/mlir/include/mlir/Conversion/OpenMPToLLVM/ConvertOpenMPToLLVM.h --- a/mlir/include/mlir/Conversion/OpenMPToLLVM/ConvertOpenMPToLLVM.h +++ b/mlir/include/mlir/Conversion/OpenMPToLLVM/ConvertOpenMPToLLVM.h @@ -19,8 +19,7 @@ class OwningRewritePatternList; /// Populate the given list with patterns that convert from OpenMP to LLVM. -void populateOpenMPToLLVMConversionPatterns(MLIRContext *context, - LLVMTypeConverter &converter, +void populateOpenMPToLLVMConversionPatterns(LLVMTypeConverter &converter, OwningRewritePatternList &patterns); /// Create a pass to convert OpenMP operations to the LLVMIR dialect. diff --git a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h --- a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h +++ b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h @@ -565,8 +565,8 @@ template class ConvertOpToLLVMPattern : public ConvertToLLVMPattern { public: - ConvertOpToLLVMPattern(LLVMTypeConverter &typeConverter, - PatternBenefit benefit = 1) + explicit ConvertOpToLLVMPattern(LLVMTypeConverter &typeConverter, + PatternBenefit benefit = 1) : ConvertToLLVMPattern(SourceOp::getOperationName(), &typeConverter.getContext(), typeConverter, benefit) {} diff --git a/mlir/lib/Conversion/AVX512ToLLVM/ConvertAVX512ToLLVM.cpp b/mlir/lib/Conversion/AVX512ToLLVM/ConvertAVX512ToLLVM.cpp --- a/mlir/lib/Conversion/AVX512ToLLVM/ConvertAVX512ToLLVM.cpp +++ b/mlir/lib/Conversion/AVX512ToLLVM/ConvertAVX512ToLLVM.cpp @@ -34,8 +34,7 @@ /// operands as is, preserve attributes. template static LogicalResult -matchAndRewriteOneToOne(const ConvertToLLVMPattern &lowering, - LLVMTypeConverter &typeConverter, Operation *op, +matchAndRewriteOneToOne(LLVMTypeConverter &typeConverter, Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) { unsigned numResults = op->getNumResults(); @@ -73,71 +72,61 @@ // TODO: Patterns are too verbose due to the fact that we have 1 op (e.g. // MaskRndScaleOp) and different possible target ops. It would be better to take // a Functor so that all these conversions become 1-liners. -struct MaskRndScaleOpPS512Conversion : public ConvertToLLVMPattern { - explicit MaskRndScaleOpPS512Conversion(MLIRContext *context, - LLVMTypeConverter &typeConverter) - : ConvertToLLVMPattern(MaskRndScaleOp::getOperationName(), context, - typeConverter) {} +struct MaskRndScaleOpPS512Conversion + : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult - matchAndRewrite(Operation *op, ArrayRef operands, + matchAndRewrite(MaskRndScaleOp op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { - if (!getSrcVectorElementType(cast(op)).isF32()) + if (!getSrcVectorElementType(op).isF32()) return failure(); return matchAndRewriteOneToOne( - *this, *getTypeConverter(), op, operands, rewriter); + *getTypeConverter(), op, operands, rewriter); } }; -struct MaskRndScaleOpPD512Conversion : public ConvertToLLVMPattern { - explicit MaskRndScaleOpPD512Conversion(MLIRContext *context, - LLVMTypeConverter &typeConverter) - : ConvertToLLVMPattern(MaskRndScaleOp::getOperationName(), context, - typeConverter) {} +struct MaskRndScaleOpPD512Conversion + : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult - matchAndRewrite(Operation *op, ArrayRef operands, + matchAndRewrite(MaskRndScaleOp op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { - if (!getSrcVectorElementType(cast(op)).isF64()) + if (!getSrcVectorElementType(op).isF64()) return failure(); return matchAndRewriteOneToOne( - *this, *getTypeConverter(), op, operands, rewriter); + *getTypeConverter(), op, operands, rewriter); } }; -struct ScaleFOpPS512Conversion : public ConvertToLLVMPattern { - explicit ScaleFOpPS512Conversion(MLIRContext *context, - LLVMTypeConverter &typeConverter) - : ConvertToLLVMPattern(MaskScaleFOp::getOperationName(), context, - typeConverter) {} +struct ScaleFOpPS512Conversion : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult - matchAndRewrite(Operation *op, ArrayRef operands, + matchAndRewrite(MaskScaleFOp op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { - if (!getSrcVectorElementType(cast(op)).isF32()) + if (!getSrcVectorElementType(op).isF32()) return failure(); return matchAndRewriteOneToOne( - *this, *getTypeConverter(), op, operands, rewriter); + *getTypeConverter(), op, operands, rewriter); } }; -struct ScaleFOpPD512Conversion : public ConvertToLLVMPattern { - explicit ScaleFOpPD512Conversion(MLIRContext *context, - LLVMTypeConverter &typeConverter) - : ConvertToLLVMPattern(MaskScaleFOp::getOperationName(), context, - typeConverter) {} +struct ScaleFOpPD512Conversion : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult - matchAndRewrite(Operation *op, ArrayRef operands, + matchAndRewrite(MaskScaleFOp op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { - if (!getSrcVectorElementType(cast(op)).isF64()) + if (!getSrcVectorElementType(op).isF64()) return failure(); return matchAndRewriteOneToOne( - *this, *getTypeConverter(), op, operands, rewriter); + *getTypeConverter(), op, operands, rewriter); } }; } // namespace @@ -145,11 +134,10 @@ /// Populate the given list with patterns that convert from AVX512 to LLVM. void mlir::populateAVX512ToLLVMConversionPatterns( LLVMTypeConverter &converter, OwningRewritePatternList &patterns) { - MLIRContext *ctx = converter.getDialect()->getContext(); // clang-format off patterns.insert(ctx, converter); + ScaleFOpPD512Conversion>(converter); // clang-format on } diff --git a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h --- a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h +++ b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h @@ -18,17 +18,13 @@ namespace mlir { template -struct GPUFuncOpLowering : ConvertToLLVMPattern { - explicit GPUFuncOpLowering(LLVMTypeConverter &typeConverter) - : ConvertToLLVMPattern(gpu::GPUFuncOp::getOperationName(), - typeConverter.getDialect()->getContext(), - typeConverter) {} +struct GPUFuncOpLowering : ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult - matchAndRewrite(Operation *op, ArrayRef operands, + matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { assert(operands.empty() && "func op is not expected to have operands"); - auto gpuFuncOp = cast(op); Location loc = gpuFuncOp.getLoc(); SmallVector workgroupBuffers; @@ -154,14 +150,11 @@ } }; -struct GPUReturnOpLowering : public ConvertToLLVMPattern { - GPUReturnOpLowering(LLVMTypeConverter &typeConverter) - : ConvertToLLVMPattern(gpu::ReturnOp::getOperationName(), - typeConverter.getDialect()->getContext(), - typeConverter) {} +struct GPUReturnOpLowering : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult - matchAndRewrite(Operation *op, ArrayRef operands, + matchAndRewrite(gpu::ReturnOp op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { rewriter.replaceOpWithNewOp(op, operands); return success(); diff --git a/mlir/lib/Conversion/GPUCommon/IndexIntrinsicsOpLowering.h b/mlir/lib/Conversion/GPUCommon/IndexIntrinsicsOpLowering.h --- a/mlir/lib/Conversion/GPUCommon/IndexIntrinsicsOpLowering.h +++ b/mlir/lib/Conversion/GPUCommon/IndexIntrinsicsOpLowering.h @@ -21,7 +21,7 @@ // `indexBitwidth`, sign-extend or truncate the resulting value to match the // bitwidth expected by the consumers of the value. template -struct GPUIndexIntrinsicOpLowering : public ConvertToLLVMPattern { +struct GPUIndexIntrinsicOpLowering : public ConvertOpToLLVMPattern { private: enum dimension { X = 0, Y = 1, Z = 2, invalid }; unsigned indexBitwidth; @@ -36,19 +36,17 @@ public: explicit GPUIndexIntrinsicOpLowering(LLVMTypeConverter &typeConverter) - : ConvertToLLVMPattern(Op::getOperationName(), - typeConverter.getDialect()->getContext(), - typeConverter), + : ConvertOpToLLVMPattern(typeConverter), indexBitwidth(typeConverter.getIndexTypeBitwidth()) {} // Convert the kernel arguments to an LLVM type, preserve the rest. LogicalResult - matchAndRewrite(Operation *op, ArrayRef operands, + matchAndRewrite(Op op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto loc = op->getLoc(); MLIRContext *context = rewriter.getContext(); Value newOp; - switch (dimensionToIndex(cast(op))) { + switch (dimensionToIndex(op)) { case X: newOp = rewriter.create(loc, LLVM::LLVMType::getInt32Ty(context)); break; diff --git a/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h b/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h --- a/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h +++ b/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h @@ -29,16 +29,15 @@ /// will be transformed into /// llvm.call @__nv_expf(%arg_f32) : (!llvm.float) -> !llvm.float template -struct OpToFuncCallLowering : public ConvertToLLVMPattern { +struct OpToFuncCallLowering : public ConvertOpToLLVMPattern { public: explicit OpToFuncCallLowering(LLVMTypeConverter &lowering_, StringRef f32Func, StringRef f64Func) - : ConvertToLLVMPattern(SourceOp::getOperationName(), - lowering_.getDialect()->getContext(), lowering_), - f32Func(f32Func), f64Func(f64Func) {} + : ConvertOpToLLVMPattern(lowering_), f32Func(f32Func), + f64Func(f64Func) {} LogicalResult - matchAndRewrite(Operation *op, ArrayRef operands, + matchAndRewrite(SourceOp op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { using LLVM::LLVMFuncOp; using LLVM::LLVMType; diff --git a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp --- a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp +++ b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp @@ -31,10 +31,8 @@ namespace { -struct GPUShuffleOpLowering : public ConvertToLLVMPattern { - explicit GPUShuffleOpLowering(LLVMTypeConverter &lowering_) - : ConvertToLLVMPattern(gpu::ShuffleOp::getOperationName(), - lowering_.getDialect()->getContext(), lowering_) {} +struct GPUShuffleOpLowering : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; /// Lowers a shuffle to the corresponding NVVM op. /// @@ -53,7 +51,7 @@ /// %shfl_pred = llvm.extractvalue %shfl[1 : index] : /// !llvm<"{ float, i1 }"> LogicalResult - matchAndRewrite(Operation *op, ArrayRef operands, + matchAndRewrite(gpu::ShuffleOp op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { Location loc = op->getLoc(); gpu::ShuffleOpAdaptor adaptor(operands); diff --git a/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp b/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp --- a/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp +++ b/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp @@ -126,19 +126,17 @@ }; // RangeOp creates a new range descriptor. -class RangeOpConversion : public ConvertToLLVMPattern { +class RangeOpConversion : public ConvertOpToLLVMPattern { public: - explicit RangeOpConversion(MLIRContext *context, LLVMTypeConverter &lowering_) - : ConvertToLLVMPattern(RangeOp::getOperationName(), context, lowering_) {} + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult - matchAndRewrite(Operation *op, ArrayRef operands, + matchAndRewrite(RangeOp rangeOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { - auto rangeOp = cast(op); auto rangeDescriptorTy = convertRangeType( rangeOp.getType().cast(), *getTypeConverter()); - edsc::ScopedContext context(rewriter, op->getLoc()); + edsc::ScopedContext context(rewriter, rangeOp->getLoc()); // Fill in an aggregate value of the descriptor. RangeOpAdaptor adaptor(operands); @@ -146,7 +144,7 @@ desc = llvm_insertvalue(desc, adaptor.min(), rewriter.getI64ArrayAttr(0)); desc = llvm_insertvalue(desc, adaptor.max(), rewriter.getI64ArrayAttr(1)); desc = llvm_insertvalue(desc, adaptor.step(), rewriter.getI64ArrayAttr(2)); - rewriter.replaceOp(op, desc); + rewriter.replaceOp(rangeOp, desc); return success(); } }; @@ -154,17 +152,13 @@ // ReshapeOp creates a new view descriptor of the proper rank. // For now, the only conversion supported is for target MemRef with static sizes // and strides. -class ReshapeOpConversion : public ConvertToLLVMPattern { +class ReshapeOpConversion : public ConvertOpToLLVMPattern { public: - explicit ReshapeOpConversion(MLIRContext *context, - LLVMTypeConverter &lowering_) - : ConvertToLLVMPattern(ReshapeOp::getOperationName(), context, - lowering_) {} + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult - matchAndRewrite(Operation *op, ArrayRef operands, + matchAndRewrite(ReshapeOp reshapeOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { - auto reshapeOp = cast(op); MemRefType dstType = reshapeOp.getResultType(); if (!dstType.hasStaticShape()) @@ -178,7 +172,7 @@ })) return failure(); - edsc::ScopedContext context(rewriter, op->getLoc()); + edsc::ScopedContext context(rewriter, reshapeOp->getLoc()); ReshapeOpAdaptor adaptor(operands); BaseViewConversionHelper baseDesc(adaptor.src()); BaseViewConversionHelper desc(typeConverter->convertType(dstType)); @@ -189,7 +183,7 @@ desc.setConstantSize(en.index(), en.value()); for (auto en : llvm::enumerate(strides)) desc.setConstantStride(en.index(), en.value()); - rewriter.replaceOp(op, {desc}); + rewriter.replaceOp(reshapeOp, {desc}); return success(); } }; @@ -200,19 +194,17 @@ /// and stride corresponding to the region of memory within the bounds of /// the parent view. /// The linalg.slice op is replaced by the alloca'ed pointer. -class SliceOpConversion : public ConvertToLLVMPattern { +class SliceOpConversion : public ConvertOpToLLVMPattern { public: - explicit SliceOpConversion(MLIRContext *context, LLVMTypeConverter &lowering_) - : ConvertToLLVMPattern(SliceOp::getOperationName(), context, lowering_) {} + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult - matchAndRewrite(Operation *op, ArrayRef operands, + matchAndRewrite(SliceOp sliceOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { - edsc::ScopedContext context(rewriter, op->getLoc()); + edsc::ScopedContext context(rewriter, sliceOp->getLoc()); SliceOpAdaptor adaptor(operands); BaseViewConversionHelper baseDesc(adaptor.view()); - auto sliceOp = cast(op); auto memRefType = sliceOp.getBaseViewType(); auto int64Ty = typeConverter->convertType(rewriter.getIntegerType(64)) .cast(); @@ -248,7 +240,7 @@ // Corner case, no sizes or strides: early return the descriptor. if (sliceOp.getShapedType().getRank() == 0) - return rewriter.replaceOp(op, {desc}), success(); + return rewriter.replaceOp(sliceOp, {desc}), success(); Value zero = llvm_constant( int64Ty, rewriter.getIntegerAttr(rewriter.getIndexType(), 0)); @@ -279,20 +271,18 @@ } } - rewriter.replaceOp(op, {desc}); + rewriter.replaceOp(sliceOp, {desc}); return success(); } }; // YieldOp produces and LLVM::ReturnOp. -class YieldOpConversion : public ConvertToLLVMPattern { +class YieldOpConversion : public ConvertOpToLLVMPattern { public: - explicit YieldOpConversion(MLIRContext *context, LLVMTypeConverter &lowering_) - : ConvertToLLVMPattern(linalg::YieldOp::getOperationName(), context, - lowering_) {} + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult - matchAndRewrite(Operation *op, ArrayRef operands, + matchAndRewrite(linalg::YieldOp op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { rewriter.replaceOpWithNewOp(op, operands); return success(); @@ -302,10 +292,9 @@ /// Populate the given list with patterns that convert from Linalg to LLVM. void mlir::populateLinalgToLLVMConversionPatterns( - LLVMTypeConverter &converter, OwningRewritePatternList &patterns, - MLIRContext *ctx) { + LLVMTypeConverter &converter, OwningRewritePatternList &patterns) { patterns.insert(ctx, converter); + YieldOpConversion>(converter); // Populate the type conversions for the linalg types. converter.addConversion( @@ -331,7 +320,7 @@ populateVectorToSCFConversionPatterns(patterns, &getContext()); populateVectorToLLVMMatrixConversionPatterns(converter, patterns); populateVectorToLLVMConversionPatterns(converter, patterns); - populateLinalgToLLVMConversionPatterns(converter, patterns, &getContext()); + populateLinalgToLLVMConversionPatterns(converter, patterns); LLVMConversionTarget target(getContext()); target.addLegalOp(); diff --git a/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp b/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp --- a/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp +++ b/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp @@ -21,34 +21,30 @@ /// expected to either be processed by the conversion infrastructure or already /// contain ops compatible with LLVM dialect types. template -struct RegionOpConversion : public ConvertToLLVMPattern { - explicit RegionOpConversion(MLIRContext *context, - LLVMTypeConverter &typeConverter) - : ConvertToLLVMPattern(OpType::getOperationName(), context, - typeConverter) {} +struct RegionOpConversion : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult - matchAndRewrite(Operation *op, ArrayRef operands, + matchAndRewrite(OpType curOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { - auto curOp = cast(op); auto newOp = rewriter.create(curOp.getLoc(), TypeRange(), operands, curOp.getAttrs()); rewriter.inlineRegionBefore(curOp.region(), newOp.region(), newOp.region().end()); - if (failed(rewriter.convertRegionTypes(&newOp.region(), *typeConverter))) + if (failed(rewriter.convertRegionTypes(&newOp.region(), + *this->getTypeConverter()))) return failure(); - rewriter.eraseOp(op); + rewriter.eraseOp(curOp); return success(); } }; } // namespace void mlir::populateOpenMPToLLVMConversionPatterns( - MLIRContext *context, LLVMTypeConverter &converter, - OwningRewritePatternList &patterns) { + LLVMTypeConverter &converter, OwningRewritePatternList &patterns) { patterns.insert, - RegionOpConversion>(context, converter); + RegionOpConversion>(converter); } namespace { @@ -60,13 +56,12 @@ void ConvertOpenMPToLLVMPass::runOnOperation() { auto module = getOperation(); - MLIRContext *context = &getContext(); // Convert to OpenMP operations with LLVM IR dialect OwningRewritePatternList patterns; LLVMTypeConverter converter(&getContext()); populateStdToLLVMConversionPatterns(converter, patterns); - populateOpenMPToLLVMConversionPatterns(context, converter, patterns); + populateOpenMPToLLVMConversionPatterns(converter, patterns); LLVMConversionTarget target(getContext()); target.addDynamicallyLegalOp( diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp @@ -296,39 +296,33 @@ /// Conversion pattern for a vector.matrix_multiply. /// This is lowered directly to the proper llvm.intr.matrix.multiply. -class VectorMatmulOpConversion : public ConvertToLLVMPattern { +class VectorMatmulOpConversion + : public ConvertOpToLLVMPattern { public: - explicit VectorMatmulOpConversion(MLIRContext *context, - LLVMTypeConverter &typeConverter) - : ConvertToLLVMPattern(vector::MatmulOp::getOperationName(), context, - typeConverter) {} + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult - matchAndRewrite(Operation *op, ArrayRef operands, + matchAndRewrite(vector::MatmulOp matmulOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { - auto matmulOp = cast(op); auto adaptor = vector::MatmulOpAdaptor(operands); rewriter.replaceOpWithNewOp( - op, typeConverter->convertType(matmulOp.res().getType()), adaptor.lhs(), - adaptor.rhs(), matmulOp.lhs_rows(), matmulOp.lhs_columns(), - matmulOp.rhs_columns()); + matmulOp, typeConverter->convertType(matmulOp.res().getType()), + adaptor.lhs(), adaptor.rhs(), matmulOp.lhs_rows(), + matmulOp.lhs_columns(), matmulOp.rhs_columns()); return success(); } }; /// Conversion pattern for a vector.flat_transpose. /// This is lowered directly to the proper llvm.intr.matrix.transpose. -class VectorFlatTransposeOpConversion : public ConvertToLLVMPattern { +class VectorFlatTransposeOpConversion + : public ConvertOpToLLVMPattern { public: - explicit VectorFlatTransposeOpConversion(MLIRContext *context, - LLVMTypeConverter &typeConverter) - : ConvertToLLVMPattern(vector::FlatTransposeOp::getOperationName(), - context, typeConverter) {} + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult - matchAndRewrite(Operation *op, ArrayRef operands, + matchAndRewrite(vector::FlatTransposeOp transOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { - auto transOp = cast(op); auto adaptor = vector::FlatTransposeOpAdaptor(operands); rewriter.replaceOpWithNewOp( transOp, typeConverter->convertType(transOp.res().getType()), @@ -338,18 +332,15 @@ }; /// Conversion pattern for a vector.maskedload. -class VectorMaskedLoadOpConversion : public ConvertToLLVMPattern { +class VectorMaskedLoadOpConversion + : public ConvertOpToLLVMPattern { public: - explicit VectorMaskedLoadOpConversion(MLIRContext *context, - LLVMTypeConverter &typeConverter) - : ConvertToLLVMPattern(vector::MaskedLoadOp::getOperationName(), context, - typeConverter) {} + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult - matchAndRewrite(Operation *op, ArrayRef operands, + matchAndRewrite(vector::MaskedLoadOp load, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { - auto loc = op->getLoc(); - auto load = cast(op); + auto loc = load->getLoc(); auto adaptor = vector::MaskedLoadOpAdaptor(operands); // Resolve alignment. @@ -371,18 +362,15 @@ }; /// Conversion pattern for a vector.maskedstore. -class VectorMaskedStoreOpConversion : public ConvertToLLVMPattern { +class VectorMaskedStoreOpConversion + : public ConvertOpToLLVMPattern { public: - explicit VectorMaskedStoreOpConversion(MLIRContext *context, - LLVMTypeConverter &typeConverter) - : ConvertToLLVMPattern(vector::MaskedStoreOp::getOperationName(), context, - typeConverter) {} + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult - matchAndRewrite(Operation *op, ArrayRef operands, + matchAndRewrite(vector::MaskedStoreOp store, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { - auto loc = op->getLoc(); - auto store = cast(op); + auto loc = store->getLoc(); auto adaptor = vector::MaskedStoreOpAdaptor(operands); // Resolve alignment. @@ -404,18 +392,15 @@ }; /// Conversion pattern for a vector.gather. -class VectorGatherOpConversion : public ConvertToLLVMPattern { +class VectorGatherOpConversion + : public ConvertOpToLLVMPattern { public: - explicit VectorGatherOpConversion(MLIRContext *context, - LLVMTypeConverter &typeConverter) - : ConvertToLLVMPattern(vector::GatherOp::getOperationName(), context, - typeConverter) {} + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult - matchAndRewrite(Operation *op, ArrayRef operands, + matchAndRewrite(vector::GatherOp gather, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { - auto loc = op->getLoc(); - auto gather = cast(op); + auto loc = gather->getLoc(); auto adaptor = vector::GatherOpAdaptor(operands); // Resolve alignment. @@ -440,18 +425,15 @@ }; /// Conversion pattern for a vector.scatter. -class VectorScatterOpConversion : public ConvertToLLVMPattern { +class VectorScatterOpConversion + : public ConvertOpToLLVMPattern { public: - explicit VectorScatterOpConversion(MLIRContext *context, - LLVMTypeConverter &typeConverter) - : ConvertToLLVMPattern(vector::ScatterOp::getOperationName(), context, - typeConverter) {} + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult - matchAndRewrite(Operation *op, ArrayRef operands, + matchAndRewrite(vector::ScatterOp scatter, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { - auto loc = op->getLoc(); - auto scatter = cast(op); + auto loc = scatter->getLoc(); auto adaptor = vector::ScatterOpAdaptor(operands); // Resolve alignment. @@ -476,18 +458,15 @@ }; /// Conversion pattern for a vector.expandload. -class VectorExpandLoadOpConversion : public ConvertToLLVMPattern { +class VectorExpandLoadOpConversion + : public ConvertOpToLLVMPattern { public: - explicit VectorExpandLoadOpConversion(MLIRContext *context, - LLVMTypeConverter &typeConverter) - : ConvertToLLVMPattern(vector::ExpandLoadOp::getOperationName(), context, - typeConverter) {} + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult - matchAndRewrite(Operation *op, ArrayRef operands, + matchAndRewrite(vector::ExpandLoadOp expand, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { - auto loc = op->getLoc(); - auto expand = cast(op); + auto loc = expand->getLoc(); auto adaptor = vector::ExpandLoadOpAdaptor(operands); Value ptr; @@ -497,25 +476,22 @@ auto vType = expand.getResultVectorType(); rewriter.replaceOpWithNewOp( - op, typeConverter->convertType(vType), ptr, adaptor.mask(), + expand, typeConverter->convertType(vType), ptr, adaptor.mask(), adaptor.pass_thru()); return success(); } }; /// Conversion pattern for a vector.compressstore. -class VectorCompressStoreOpConversion : public ConvertToLLVMPattern { +class VectorCompressStoreOpConversion + : public ConvertOpToLLVMPattern { public: - explicit VectorCompressStoreOpConversion(MLIRContext *context, - LLVMTypeConverter &typeConverter) - : ConvertToLLVMPattern(vector::CompressStoreOp::getOperationName(), - context, typeConverter) {} + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult - matchAndRewrite(Operation *op, ArrayRef operands, + matchAndRewrite(vector::CompressStoreOp compress, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { - auto loc = op->getLoc(); - auto compress = cast(op); + auto loc = compress->getLoc(); auto adaptor = vector::CompressStoreOpAdaptor(operands); Value ptr; @@ -524,25 +500,23 @@ return failure(); rewriter.replaceOpWithNewOp( - op, adaptor.value(), ptr, adaptor.mask()); + compress, adaptor.value(), ptr, adaptor.mask()); return success(); } }; /// Conversion pattern for all vector reductions. -class VectorReductionOpConversion : public ConvertToLLVMPattern { +class VectorReductionOpConversion + : public ConvertOpToLLVMPattern { public: - explicit VectorReductionOpConversion(MLIRContext *context, - LLVMTypeConverter &typeConverter, + explicit VectorReductionOpConversion(LLVMTypeConverter &typeConv, bool reassociateFPRed) - : ConvertToLLVMPattern(vector::ReductionOp::getOperationName(), context, - typeConverter), + : ConvertOpToLLVMPattern(typeConv), reassociateFPReductions(reassociateFPRed) {} LogicalResult - matchAndRewrite(Operation *op, ArrayRef operands, + matchAndRewrite(vector::ReductionOp reductionOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { - auto reductionOp = cast(op); auto kind = reductionOp.kind(); Type eltType = reductionOp.dest().getType(); Type llvmType = typeConverter->convertType(eltType); @@ -550,33 +524,33 @@ // Integer reductions: add/mul/min/max/and/or/xor. if (kind == "add") rewriter.replaceOpWithNewOp( - op, llvmType, operands[0]); + reductionOp, llvmType, operands[0]); else if (kind == "mul") rewriter.replaceOpWithNewOp( - op, llvmType, operands[0]); + reductionOp, llvmType, operands[0]); else if (kind == "min" && (eltType.isIndex() || eltType.isUnsignedInteger())) rewriter.replaceOpWithNewOp( - op, llvmType, operands[0]); + reductionOp, llvmType, operands[0]); else if (kind == "min") rewriter.replaceOpWithNewOp( - op, llvmType, operands[0]); + reductionOp, llvmType, operands[0]); else if (kind == "max" && (eltType.isIndex() || eltType.isUnsignedInteger())) rewriter.replaceOpWithNewOp( - op, llvmType, operands[0]); + reductionOp, llvmType, operands[0]); else if (kind == "max") rewriter.replaceOpWithNewOp( - op, llvmType, operands[0]); + reductionOp, llvmType, operands[0]); else if (kind == "and") rewriter.replaceOpWithNewOp( - op, llvmType, operands[0]); + reductionOp, llvmType, operands[0]); else if (kind == "or") rewriter.replaceOpWithNewOp( - op, llvmType, operands[0]); + reductionOp, llvmType, operands[0]); else if (kind == "xor") rewriter.replaceOpWithNewOp( - op, llvmType, operands[0]); + reductionOp, llvmType, operands[0]); else return failure(); return success(); @@ -590,27 +564,27 @@ // Optional accumulator (or zero). Value acc = operands.size() > 1 ? operands[1] : rewriter.create( - op->getLoc(), llvmType, + reductionOp->getLoc(), llvmType, rewriter.getZeroAttr(eltType)); rewriter.replaceOpWithNewOp( - op, llvmType, acc, operands[0], + reductionOp, llvmType, acc, operands[0], rewriter.getBoolAttr(reassociateFPReductions)); } else if (kind == "mul") { // Optional accumulator (or one). Value acc = operands.size() > 1 ? operands[1] : rewriter.create( - op->getLoc(), llvmType, + reductionOp->getLoc(), llvmType, rewriter.getFloatAttr(eltType, 1.0)); rewriter.replaceOpWithNewOp( - op, llvmType, acc, operands[0], + reductionOp, llvmType, acc, operands[0], rewriter.getBoolAttr(reassociateFPReductions)); } else if (kind == "min") - rewriter.replaceOpWithNewOp(op, llvmType, - operands[0]); + rewriter.replaceOpWithNewOp( + reductionOp, llvmType, operands[0]); else if (kind == "max") - rewriter.replaceOpWithNewOp(op, llvmType, - operands[0]); + rewriter.replaceOpWithNewOp( + reductionOp, llvmType, operands[0]); else return failure(); return success(); @@ -621,17 +595,16 @@ }; /// Conversion pattern for a vector.create_mask (1-D only). -class VectorCreateMaskOpConversion : public ConvertToLLVMPattern { +class VectorCreateMaskOpConversion + : public ConvertOpToLLVMPattern { public: - explicit VectorCreateMaskOpConversion(MLIRContext *context, - LLVMTypeConverter &typeConverter, + explicit VectorCreateMaskOpConversion(LLVMTypeConverter &typeConv, bool enableIndexOpt) - : ConvertToLLVMPattern(vector::CreateMaskOp::getOperationName(), context, - typeConverter), + : ConvertOpToLLVMPattern(typeConv), enableIndexOptimizations(enableIndexOpt) {} LogicalResult - matchAndRewrite(Operation *op, ArrayRef operands, + matchAndRewrite(vector::CreateMaskOp op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto dstType = op->getResult(0).getType().cast(); int64_t rank = dstType.getRank(); @@ -648,19 +621,16 @@ const bool enableIndexOptimizations; }; -class VectorShuffleOpConversion : public ConvertToLLVMPattern { +class VectorShuffleOpConversion + : public ConvertOpToLLVMPattern { public: - explicit VectorShuffleOpConversion(MLIRContext *context, - LLVMTypeConverter &typeConverter) - : ConvertToLLVMPattern(vector::ShuffleOp::getOperationName(), context, - typeConverter) {} + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult - matchAndRewrite(Operation *op, ArrayRef operands, + matchAndRewrite(vector::ShuffleOp shuffleOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { - auto loc = op->getLoc(); + auto loc = shuffleOp->getLoc(); auto adaptor = vector::ShuffleOpAdaptor(operands); - auto shuffleOp = cast(op); auto v1Type = shuffleOp.getV1VectorType(); auto v2Type = shuffleOp.getV2VectorType(); auto vectorType = shuffleOp.getVectorType(); @@ -680,9 +650,9 @@ // For rank 1, where both operands have *exactly* the same vector type, // there is direct shuffle support in LLVM. Use it! if (rank == 1 && v1Type == v2Type) { - Value shuffle = rewriter.create( + Value llvmShuffleOp = rewriter.create( loc, adaptor.v1(), adaptor.v2(), maskArrayAttr); - rewriter.replaceOp(op, shuffle); + rewriter.replaceOp(shuffleOp, llvmShuffleOp); return success(); } @@ -701,23 +671,22 @@ insert = insertOne(rewriter, *getTypeConverter(), loc, insert, extract, llvmType, rank, insPos++); } - rewriter.replaceOp(op, insert); + rewriter.replaceOp(shuffleOp, insert); return success(); } }; -class VectorExtractElementOpConversion : public ConvertToLLVMPattern { +class VectorExtractElementOpConversion + : public ConvertOpToLLVMPattern { public: - explicit VectorExtractElementOpConversion(MLIRContext *context, - LLVMTypeConverter &typeConverter) - : ConvertToLLVMPattern(vector::ExtractElementOp::getOperationName(), - context, typeConverter) {} + using ConvertOpToLLVMPattern< + vector::ExtractElementOp>::ConvertOpToLLVMPattern; LogicalResult - matchAndRewrite(Operation *op, ArrayRef operands, + matchAndRewrite(vector::ExtractElementOp extractEltOp, + ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto adaptor = vector::ExtractElementOpAdaptor(operands); - auto extractEltOp = cast(op); auto vectorType = extractEltOp.getVectorType(); auto llvmType = typeConverter->convertType(vectorType.getElementType()); @@ -726,24 +695,21 @@ return failure(); rewriter.replaceOpWithNewOp( - op, llvmType, adaptor.vector(), adaptor.position()); + extractEltOp, llvmType, adaptor.vector(), adaptor.position()); return success(); } }; -class VectorExtractOpConversion : public ConvertToLLVMPattern { +class VectorExtractOpConversion + : public ConvertOpToLLVMPattern { public: - explicit VectorExtractOpConversion(MLIRContext *context, - LLVMTypeConverter &typeConverter) - : ConvertToLLVMPattern(vector::ExtractOp::getOperationName(), context, - typeConverter) {} + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult - matchAndRewrite(Operation *op, ArrayRef operands, + matchAndRewrite(vector::ExtractOp extractOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { - auto loc = op->getLoc(); + auto loc = extractOp->getLoc(); auto adaptor = vector::ExtractOpAdaptor(operands); - auto extractOp = cast(op); auto vectorType = extractOp.getVectorType(); auto resultType = extractOp.getResult().getType(); auto llvmResultType = typeConverter->convertType(resultType); @@ -757,12 +723,12 @@ if (resultType.isa()) { Value extracted = rewriter.create( loc, llvmResultType, adaptor.vector(), positionArrayAttr); - rewriter.replaceOp(op, extracted); + rewriter.replaceOp(extractOp, extracted); return success(); } // Potential extraction of 1-D vector from array. - auto *context = op->getContext(); + auto *context = extractOp->getContext(); Value extracted = adaptor.vector(); auto positionAttrs = positionArrayAttr.getValue(); if (positionAttrs.size() > 1) { @@ -780,7 +746,7 @@ auto constant = rewriter.create(loc, i64Type, position); extracted = rewriter.create(loc, extracted, constant); - rewriter.replaceOp(op, extracted); + rewriter.replaceOp(extractOp, extracted); return success(); } @@ -800,39 +766,32 @@ /// (!llvm<"<8 x float>">, !llvm<"<8 x float>">, !llvm<"<8 x float>">) /// -> !llvm<"<8 x float>"> /// ``` -class VectorFMAOp1DConversion : public ConvertToLLVMPattern { +class VectorFMAOp1DConversion : public ConvertOpToLLVMPattern { public: - explicit VectorFMAOp1DConversion(MLIRContext *context, - LLVMTypeConverter &typeConverter) - : ConvertToLLVMPattern(vector::FMAOp::getOperationName(), context, - typeConverter) {} + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult - matchAndRewrite(Operation *op, ArrayRef operands, + matchAndRewrite(vector::FMAOp fmaOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto adaptor = vector::FMAOpAdaptor(operands); - vector::FMAOp fmaOp = cast(op); VectorType vType = fmaOp.getVectorType(); if (vType.getRank() != 1) return failure(); - rewriter.replaceOpWithNewOp(op, adaptor.lhs(), + rewriter.replaceOpWithNewOp(fmaOp, adaptor.lhs(), adaptor.rhs(), adaptor.acc()); return success(); } }; -class VectorInsertElementOpConversion : public ConvertToLLVMPattern { +class VectorInsertElementOpConversion + : public ConvertOpToLLVMPattern { public: - explicit VectorInsertElementOpConversion(MLIRContext *context, - LLVMTypeConverter &typeConverter) - : ConvertToLLVMPattern(vector::InsertElementOp::getOperationName(), - context, typeConverter) {} + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult - matchAndRewrite(Operation *op, ArrayRef operands, + matchAndRewrite(vector::InsertElementOp insertEltOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto adaptor = vector::InsertElementOpAdaptor(operands); - auto insertEltOp = cast(op); auto vectorType = insertEltOp.getDestVectorType(); auto llvmType = typeConverter->convertType(vectorType); @@ -841,24 +800,22 @@ return failure(); rewriter.replaceOpWithNewOp( - op, llvmType, adaptor.dest(), adaptor.source(), adaptor.position()); + insertEltOp, llvmType, adaptor.dest(), adaptor.source(), + adaptor.position()); return success(); } }; -class VectorInsertOpConversion : public ConvertToLLVMPattern { +class VectorInsertOpConversion + : public ConvertOpToLLVMPattern { public: - explicit VectorInsertOpConversion(MLIRContext *context, - LLVMTypeConverter &typeConverter) - : ConvertToLLVMPattern(vector::InsertOp::getOperationName(), context, - typeConverter) {} + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult - matchAndRewrite(Operation *op, ArrayRef operands, + matchAndRewrite(vector::InsertOp insertOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { - auto loc = op->getLoc(); + auto loc = insertOp->getLoc(); auto adaptor = vector::InsertOpAdaptor(operands); - auto insertOp = cast(op); auto sourceType = insertOp.getSourceType(); auto destVectorType = insertOp.getDestVectorType(); auto llvmResultType = typeConverter->convertType(destVectorType); @@ -873,12 +830,12 @@ Value inserted = rewriter.create( loc, llvmResultType, adaptor.dest(), adaptor.source(), positionArrayAttr); - rewriter.replaceOp(op, inserted); + rewriter.replaceOp(insertOp, inserted); return success(); } // Potential extraction of 1-D vector from array. - auto *context = op->getContext(); + auto *context = insertOp->getContext(); Value extracted = adaptor.dest(); auto positionAttrs = positionArrayAttr.getValue(); auto position = positionAttrs.back().cast(); @@ -908,7 +865,7 @@ nMinusOnePositionAttrs); } - rewriter.replaceOp(op, inserted); + rewriter.replaceOp(insertOp, inserted); return success(); } }; @@ -1117,18 +1074,15 @@ return strides; } -class VectorTypeCastOpConversion : public ConvertToLLVMPattern { +class VectorTypeCastOpConversion + : public ConvertOpToLLVMPattern { public: - explicit VectorTypeCastOpConversion(MLIRContext *context, - LLVMTypeConverter &typeConverter) - : ConvertToLLVMPattern(vector::TypeCastOp::getOperationName(), context, - typeConverter) {} + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult - matchAndRewrite(Operation *op, ArrayRef operands, + matchAndRewrite(vector::TypeCastOp castOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { - auto loc = op->getLoc(); - vector::TypeCastOp castOp = cast(op); + auto loc = castOp->getLoc(); MemRefType sourceMemRefType = castOp.getOperand().getType().cast(); MemRefType targetMemRefType = @@ -1195,7 +1149,7 @@ desc.setStride(rewriter, loc, index, stride); } - rewriter.replaceOp(op, {desc}); + rewriter.replaceOp(castOp, {desc}); return success(); } }; @@ -1208,18 +1162,16 @@ /// 4. Create a mask where offsetVector is compared against memref upper bound. /// 5. Rewrite op as a masked read or write. template -class VectorTransferConversion : public ConvertToLLVMPattern { +class VectorTransferConversion : public ConvertOpToLLVMPattern { public: - explicit VectorTransferConversion(MLIRContext *context, - LLVMTypeConverter &typeConv, + explicit VectorTransferConversion(LLVMTypeConverter &typeConv, bool enableIndexOpt) - : ConvertToLLVMPattern(ConcreteOp::getOperationName(), context, typeConv), + : ConvertOpToLLVMPattern(typeConv), enableIndexOptimizations(enableIndexOpt) {} LogicalResult - matchAndRewrite(Operation *op, ArrayRef operands, + matchAndRewrite(ConcreteOp xferOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { - auto xferOp = cast(op); auto adaptor = getTransferOpAdapter(xferOp, operands); if (xferOp.getVectorType().getRank() > 1 || @@ -1228,16 +1180,18 @@ if (xferOp.permutation_map() != AffineMap::getMinorIdentityMap(xferOp.permutation_map().getNumInputs(), xferOp.getVectorType().getRank(), - op->getContext())) + xferOp->getContext())) return failure(); // Only contiguous source tensors supported atm. auto strides = computeContiguousStrides(xferOp.getMemRefType()); if (!strides) return failure(); - auto toLLVMTy = [&](Type t) { return typeConverter->convertType(t); }; + auto toLLVMTy = [&](Type t) { + return this->getTypeConverter()->convertType(t); + }; - Location loc = op->getLoc(); + Location loc = xferOp->getLoc(); MemRefType memRefType = xferOp.getMemRefType(); if (auto memrefVectorElementType = @@ -1267,8 +1221,8 @@ // addrspacecast shall be used when source/dst memrefs are not on // address space 0. // TODO: support alignment when possible. - Value dataPtr = getStridedElementPtr(loc, memRefType, adaptor.memref(), - adaptor.indices(), rewriter); + Value dataPtr = this->getStridedElementPtr( + loc, memRefType, adaptor.memref(), adaptor.indices(), rewriter); auto vecTy = toLLVMTy(xferOp.getVectorType()).template cast(); Value vectorDataPtr; @@ -1280,8 +1234,9 @@ loc, vecTy.getPointerTo(), dataPtr); if (!xferOp.isMaskedDim(0)) - return replaceTransferOpWithLoadOrStore( - rewriter, *getTypeConverter(), loc, xferOp, operands, vectorDataPtr); + return replaceTransferOpWithLoadOrStore(rewriter, + *this->getTypeConverter(), loc, + xferOp, operands, vectorDataPtr); // 2. Create a vector with linear indices [ 0 .. vector_length - 1 ]. // 3. Create offsetVector = [ offset + 0 .. offset + vector_length - 1 ]. @@ -1294,11 +1249,11 @@ unsigned lastIndex = llvm::size(xferOp.indices()) - 1; Value off = xferOp.indices()[lastIndex]; Value dim = rewriter.create(loc, xferOp.memref(), lastIndex); - Value mask = buildVectorComparison(rewriter, op, enableIndexOptimizations, - vecWidth, dim, &off); + Value mask = buildVectorComparison( + rewriter, xferOp, enableIndexOptimizations, vecWidth, dim, &off); // 5. Rewrite as a masked read / write. - return replaceTransferOpWithMasked(rewriter, *getTypeConverter(), loc, + return replaceTransferOpWithMasked(rewriter, *this->getTypeConverter(), loc, xferOp, operands, vectorDataPtr, mask); } @@ -1306,12 +1261,9 @@ const bool enableIndexOptimizations; }; -class VectorPrintOpConversion : public ConvertToLLVMPattern { +class VectorPrintOpConversion : public ConvertOpToLLVMPattern { public: - explicit VectorPrintOpConversion(MLIRContext *context, - LLVMTypeConverter &typeConverter) - : ConvertToLLVMPattern(vector::PrintOp::getOperationName(), context, - typeConverter) {} + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; // Proof-of-concept lowering implementation that relies on a small // runtime support library, which only needs to provide a few @@ -1326,9 +1278,8 @@ // TODO: rely solely on libc in future? something else? // LogicalResult - matchAndRewrite(Operation *op, ArrayRef operands, + matchAndRewrite(vector::PrintOp printOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { - auto printOp = cast(op); auto adaptor = vector::PrintOpAdaptor(operands); Type printType = printOp.getPrintType(); @@ -1341,11 +1292,11 @@ Type eltType = vectorType ? vectorType.getElementType() : printType; Operation *printer; if (eltType.isF32()) { - printer = getPrintFloat(op); + printer = getPrintFloat(printOp); } else if (eltType.isF64()) { - printer = getPrintDouble(op); + printer = getPrintDouble(printOp); } else if (eltType.isIndex()) { - printer = getPrintU64(op); + printer = getPrintU64(printOp); } else if (auto intTy = eltType.dyn_cast()) { // Integers need a zero or sign extension on the operand // (depending on the source type) as well as a signed or @@ -1355,7 +1306,7 @@ if (width <= 64) { if (width < 64) conversion = PrintConversion::ZeroExt64; - printer = getPrintU64(op); + printer = getPrintU64(printOp); } else { return failure(); } @@ -1368,7 +1319,7 @@ conversion = PrintConversion::ZeroExt64; else if (width < 64) conversion = PrintConversion::SignExt64; - printer = getPrintI64(op); + printer = getPrintI64(printOp); } else { return failure(); } @@ -1379,10 +1330,10 @@ // Unroll vector into elementary print calls. int64_t rank = vectorType ? vectorType.getRank() : 0; - emitRanks(rewriter, op, adaptor.source(), vectorType, printer, rank, + emitRanks(rewriter, printOp, adaptor.source(), vectorType, printer, rank, conversion); - emitCall(rewriter, op->getLoc(), getPrintNewline(op)); - rewriter.eraseOp(op); + emitCall(rewriter, printOp->getLoc(), getPrintNewline(printOp)); + rewriter.eraseOp(printOp); return success(); } @@ -1560,11 +1511,11 @@ VectorInsertStridedSliceOpSameRankRewritePattern, VectorExtractStridedSliceOpConversion>(ctx); patterns.insert( - ctx, converter, reassociateFPReductions); + converter, reassociateFPReductions); patterns.insert, VectorTransferConversion>( - ctx, converter, enableIndexOptimizations); + converter, enableIndexOptimizations); patterns .insert(ctx, converter); + VectorCompressStoreOpConversion>(converter); // clang-format on } void mlir::populateVectorToLLVMMatrixConversionPatterns( LLVMTypeConverter &converter, OwningRewritePatternList &patterns) { - MLIRContext *ctx = converter.getDialect()->getContext(); - patterns.insert(ctx, converter); - patterns.insert(ctx, converter); + patterns.insert(converter); + patterns.insert(converter); } diff --git a/mlir/lib/Conversion/VectorToROCDL/VectorToROCDL.cpp b/mlir/lib/Conversion/VectorToROCDL/VectorToROCDL.cpp --- a/mlir/lib/Conversion/VectorToROCDL/VectorToROCDL.cpp +++ b/mlir/lib/Conversion/VectorToROCDL/VectorToROCDL.cpp @@ -55,17 +55,13 @@ /// types. For unsupported cases, they will fall back to the vector to /// llvm conversion pattern. template -class VectorTransferConversion : public ConvertToLLVMPattern { +class VectorTransferConversion : public ConvertOpToLLVMPattern { public: - explicit VectorTransferConversion(MLIRContext *context, - LLVMTypeConverter &typeConv) - : ConvertToLLVMPattern(ConcreteOp::getOperationName(), context, - typeConv) {} + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult - matchAndRewrite(Operation *op, ArrayRef operands, + matchAndRewrite(ConcreteOp xferOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { - auto xferOp = cast(op); typename ConcreteOp::Adaptor adaptor(operands); if (xferOp.getVectorType().getRank() > 1 || @@ -79,11 +75,13 @@ if (!xferOp.isMaskedDim(0)) return failure(); - auto toLLVMTy = [&](Type t) { return typeConverter->convertType(t); }; + auto toLLVMTy = [&](Type t) { + return this->getTypeConverter()->convertType(t); + }; LLVM::LLVMType vecTy = toLLVMTy(xferOp.getVectorType()).template cast(); unsigned vecWidth = vecTy.getVectorNumElements(); - Location loc = op->getLoc(); + Location loc = xferOp->getLoc(); // The backend result vector scalarization have trouble scalarize // <1 x ty> result, exclude the x1 width from the lowering. @@ -102,8 +100,8 @@ // Note that the dataPtr starts at the offset address specified by // indices, so no need to calculate offset size in bytes again in // the MUBUF instruction. - Value dataPtr = getStridedElementPtr(loc, memRefType, adaptor.memref(), - adaptor.indices(), rewriter); + Value dataPtr = this->getStridedElementPtr( + loc, memRefType, adaptor.memref(), adaptor.indices(), rewriter); // 1. Create and fill a <4 x i32> dwordConfig with: // 1st two elements holding the address of dataPtr. @@ -126,7 +124,7 @@ constConfig); Value dataPtrAsI64 = rewriter.create( loc, toLLVMTy(i64Ty).template cast(), dataPtr); - Value zero = createIndexConstant(rewriter, loc, 0); + Value zero = this->createIndexConstant(rewriter, loc, 0); Value dwordConfig = rewriter.create( loc, LLVM::LLVMType::getVectorTy( @@ -143,7 +141,7 @@ loc, toLLVMTy(i32Ty), rewriter.getIntegerAttr(rewriter.getIntegerType(32), 0)); return replaceTransferOpWithMubuf( - rewriter, operands, *getTypeConverter(), loc, xferOp, vecTy, + rewriter, operands, *this->getTypeConverter(), loc, xferOp, vecTy, dwordConfig, int32Zero, int32Zero, int1False, int1False); } }; @@ -151,9 +149,8 @@ void mlir::populateVectorToROCDLConversionPatterns( LLVMTypeConverter &converter, OwningRewritePatternList &patterns) { - MLIRContext *ctx = converter.getDialect()->getContext(); patterns.insert, - VectorTransferConversion>(ctx, converter); + VectorTransferConversion>(converter); } namespace {