diff --git a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h --- a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h +++ b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h @@ -363,17 +363,15 @@ static unsigned getNumUnpackedValues() { return 2; } }; /// Base class for operation conversions targeting the LLVM IR dialect. Provides -/// conversion patterns with an access to the containing LLVMLowering for the -/// purpose of type conversions. +/// conversion patterns with access to an LLVMTypeConverter. class LLVMOpLowering : public ConversionPattern { public: LLVMOpLowering(StringRef rootOpName, MLIRContext *context, - LLVMTypeConverter &lowering, PatternBenefit benefit = 1); + LLVMTypeConverter &typeConverter, PatternBenefit benefit = 1); protected: - // Back-reference to the lowering class, used to call type and function - // conversions accounting for potential extensions. - LLVMTypeConverter &lowering; + /// Reference to the type converter, with potential extensions. + LLVMTypeConverter &typeConverter; }; } // namespace mlir diff --git a/mlir/lib/Conversion/GPUCommon/IndexIntrinsicsOpLowering.h b/mlir/lib/Conversion/GPUCommon/IndexIntrinsicsOpLowering.h --- a/mlir/lib/Conversion/GPUCommon/IndexIntrinsicsOpLowering.h +++ b/mlir/lib/Conversion/GPUCommon/IndexIntrinsicsOpLowering.h @@ -51,7 +51,7 @@ matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto loc = op->getLoc(); - auto dialect = lowering.getDialect(); + auto dialect = typeConverter.getDialect(); Value newOp; switch (dimensionToIndex(cast(op))) { case X: diff --git a/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h b/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h --- a/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h +++ b/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h @@ -44,7 +44,7 @@ std::is_base_of, SourceOp>::value, "expected single result op"); - LLVMType resultType = lowering.convertType(op->getResult(0).getType()) + LLVMType resultType = typeConverter.convertType(op->getResult(0).getType()) .template cast(); LLVMType funcType = getFunctionType(resultType, operands); StringRef funcName = getFunctionName(resultType); diff --git a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp --- a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp +++ b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp @@ -337,7 +337,7 @@ // Clamp lane: `activeWidth - 1` Value maskAndClamp = rewriter.create(loc, int32Type, activeWidth, one); - auto dialect = lowering.getDialect(); + auto dialect = typeConverter.getDialect(); auto predTy = LLVM::LLVMType::getInt1Ty(dialect); auto shflTy = LLVM::LLVMType::getStructTy(dialect, {type, predTy}); auto returnValueAndIsValidAttr = rewriter.getUnitAttr(); @@ -490,7 +490,7 @@ Location loc = op->getLoc(); gpu::ShuffleOpOperandAdaptor adaptor(operands); - auto dialect = lowering.getDialect(); + auto dialect = typeConverter.getDialect(); auto valueTy = adaptor.value().getType().cast(); auto int32Type = LLVM::LLVMType::getInt32Ty(dialect); auto predTy = LLVM::LLVMType::getInt1Ty(dialect); @@ -544,8 +544,8 @@ uint64_t numElements = type.getNumElements(); - auto elementType = - lowering.convertType(type.getElementType()).cast(); + auto elementType = typeConverter.convertType(type.getElementType()) + .cast(); auto arrayType = LLVM::LLVMType::getArrayTy(elementType, numElements); std::string name = std::string( llvm::formatv("__wg_{0}_{1}", gpuFuncOp.getName(), en.index())); @@ -557,15 +557,15 @@ } // Rewrite the original GPU function to an LLVM function. - auto funcType = lowering.convertType(gpuFuncOp.getType()) + auto funcType = typeConverter.convertType(gpuFuncOp.getType()) .cast() .getPointerElementTy(); // Remap proper input types. TypeConverter::SignatureConversion signatureConversion( gpuFuncOp.front().getNumArguments()); - lowering.convertFunctionSignature(gpuFuncOp.getType(), /*isVariadic=*/false, - signatureConversion); + typeConverter.convertFunctionSignature( + gpuFuncOp.getType(), /*isVariadic=*/false, signatureConversion); // Create the new function operation. Only copy those attributes that are // not specific to function modeling. @@ -592,7 +592,7 @@ // Rewrite workgroup memory attributions to addresses of global buffers. rewriter.setInsertionPointToStart(&gpuFuncOp.front()); unsigned numProperArguments = gpuFuncOp.getNumArguments(); - auto i32Type = LLVM::LLVMType::getInt32Ty(lowering.getDialect()); + auto i32Type = LLVM::LLVMType::getInt32Ty(typeConverter.getDialect()); Value zero = nullptr; if (!workgroupBuffers.empty()) @@ -612,15 +612,15 @@ // and canonicalize that away later. Value attribution = gpuFuncOp.getWorkgroupAttributions()[en.index()]; auto type = attribution.getType().cast(); - auto descr = MemRefDescriptor::fromStaticShape(rewriter, loc, lowering, - type, memory); + auto descr = MemRefDescriptor::fromStaticShape( + rewriter, loc, typeConverter, type, memory); signatureConversion.remapInput(numProperArguments + en.index(), descr); } // Rewrite private memory attributions to alloca'ed buffers. unsigned numWorkgroupAttributions = gpuFuncOp.getNumWorkgroupAttributions(); - auto int64Ty = LLVM::LLVMType::getInt64Ty(lowering.getDialect()); + auto int64Ty = LLVM::LLVMType::getInt64Ty(typeConverter.getDialect()); for (auto en : llvm::enumerate(gpuFuncOp.getPrivateAttributions())) { Value attribution = en.value(); auto type = attribution.getType().cast(); @@ -630,7 +630,7 @@ // Explicitly drop memory space when lowering private memory // attributions since NVVM models it as `alloca`s in the default // memory space and does not support `alloca`s with addrspace(5). - auto ptrType = lowering.convertType(type.getElementType()) + auto ptrType = typeConverter.convertType(type.getElementType()) .cast() .getPointerTo(); Value numElements = rewriter.create( @@ -638,8 +638,8 @@ rewriter.getI64IntegerAttr(type.getNumElements())); Value allocated = rewriter.create( gpuFuncOp.getLoc(), ptrType, numElements, /*alignment=*/0); - auto descr = MemRefDescriptor::fromStaticShape(rewriter, loc, lowering, - type, allocated); + auto descr = MemRefDescriptor::fromStaticShape( + rewriter, loc, typeConverter, type, allocated); signatureConversion.remapInput( numProperArguments + numWorkgroupAttributions + en.index(), descr); } diff --git a/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp b/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp --- a/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp +++ b/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp @@ -146,7 +146,7 @@ ConversionPatternRewriter &rewriter) const override { auto rangeOp = cast(op); auto rangeDescriptorTy = - convertLinalgType(rangeOp.getResult().getType(), lowering); + convertLinalgType(rangeOp.getResult().getType(), typeConverter); edsc::ScopedContext context(rewriter, op->getLoc()); @@ -190,7 +190,7 @@ edsc::ScopedContext context(rewriter, op->getLoc()); ReshapeOpOperandAdaptor adaptor(operands); BaseViewConversionHelper baseDesc(adaptor.view()); - BaseViewConversionHelper desc(lowering.convertType(dstType)); + BaseViewConversionHelper desc(typeConverter.convertType(dstType)); desc.setAllocatedPtr(baseDesc.allocatedPtr()); desc.setAlignedPtr(baseDesc.alignedPtr()); desc.setOffset(baseDesc.offset()); @@ -225,11 +225,11 @@ auto sliceOp = cast(op); auto memRefType = sliceOp.getBaseViewType(); - auto int64Ty = lowering.convertType(rewriter.getIntegerType(64)) + auto int64Ty = typeConverter.convertType(rewriter.getIntegerType(64)) .cast(); BaseViewConversionHelper desc( - lowering.convertType(sliceOp.getShapedType())); + typeConverter.convertType(sliceOp.getShapedType())); // TODO(ntv): extract sizes and emit asserts. SmallVector strides(memRefType.getRank()); @@ -322,7 +322,7 @@ return rewriter.replaceOp(op, {baseDesc}), matchSuccess(); BaseViewConversionHelper desc( - lowering.convertType(transposeOp.getShapedType())); + typeConverter.convertType(transposeOp.getShapedType())); // Copy the base and aligned pointers from the old descriptor to the new // one. diff --git a/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp --- a/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp +++ b/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp @@ -376,9 +376,10 @@ } LLVMOpLowering::LLVMOpLowering(StringRef rootOpName, MLIRContext *context, - LLVMTypeConverter &lowering_, + LLVMTypeConverter &typeConverter_, PatternBenefit benefit) - : ConversionPattern(rootOpName, benefit, context), lowering(lowering_) {} + : ConversionPattern(rootOpName, benefit, context), + typeConverter(typeConverter_) {} /*============================================================================*/ /* StructBuilder implementation */ @@ -706,9 +707,9 @@ public: // Construct a conversion pattern. explicit LLVMLegalizationPattern(LLVM::LLVMDialect &dialect_, - LLVMTypeConverter &lowering_) + LLVMTypeConverter &typeConverter_) : LLVMOpLowering(SourceOp::getOperationName(), dialect_.getContext(), - lowering_), + typeConverter_), dialect(dialect_) {} // Get the LLVM IR dialect. @@ -904,7 +905,7 @@ // LLVMTypeConverter provided to this legalization pattern. auto varargsAttr = funcOp.getAttrOfType("std.varargs"); TypeConverter::SignatureConversion result(funcOp.getNumArguments()); - auto llvmType = lowering.convertFunctionSignature( + auto llvmType = typeConverter.convertFunctionSignature( funcOp.getType(), varargsAttr && varargsAttr.getValue(), result); // Propagate argument attributes to all converted arguments obtained after @@ -957,10 +958,10 @@ auto newFuncOp = convertFuncOpToLLVMFuncOp(funcOp, rewriter); if (emitWrappers) { if (newFuncOp.isExternal()) - wrapExternalFunction(rewriter, op->getLoc(), lowering, funcOp, + wrapExternalFunction(rewriter, op->getLoc(), typeConverter, funcOp, newFuncOp); else - wrapForExternalCallers(rewriter, op->getLoc(), lowering, funcOp, + wrapForExternalCallers(rewriter, op->getLoc(), typeConverter, funcOp, newFuncOp); } @@ -1014,7 +1015,7 @@ rewriter.create(funcLoc, arg.getType()); rewriter.replaceUsesOfBlockArgument(arg, placeHolder); auto desc = MemRefDescriptor::fromStaticShape( - rewriter, funcLoc, lowering, memrefType, arg); + rewriter, funcLoc, typeConverter, memrefType, arg); rewriter.replaceOp(placeHolder.getDefiningOp(), {desc}); } } @@ -1119,7 +1120,8 @@ Type packedType; if (numResults != 0) { - packedType = this->lowering.packFunctionResults(op->getResultTypes()); + packedType = + this->typeConverter.packFunctionResults(op->getResultTypes()); if (!packedType) return this->matchFailure(); } @@ -1139,7 +1141,7 @@ SmallVector results; results.reserve(numResults); for (unsigned i = 0; i < numResults; ++i) { - auto type = this->lowering.convertType(op->getResult(i).getType()); + auto type = this->typeConverter.convertType(op->getResult(i).getType()); results.push_back(rewriter.create( op->getLoc(), type, newOp.getOperation()->getResult(0), rewriter.getI64ArrayAttr(i))); @@ -1206,7 +1208,8 @@ auto vectorType = op->getResult(0).getType().dyn_cast(); if (!vectorType) return this->matchFailure(); - auto vectorTypeInfo = extractNDVectorTypeInfo(vectorType, this->lowering); + auto vectorTypeInfo = + extractNDVectorTypeInfo(vectorType, this->typeConverter); auto llvmVectorTy = vectorTypeInfo.llvmVectorTy; if (!llvmVectorTy || llvmArrayTy != vectorTypeInfo.llvmArrayTy) return this->matchFailure(); @@ -1416,8 +1419,9 @@ // %1 = ptrtoint %elementType* %0 to %indexType // which is a common pattern of getting the size of a type in bytes. auto elementType = type.getElementType(); - auto convertedPtrType = - lowering.convertType(elementType).cast().getPointerTo(); + auto convertedPtrType = typeConverter.convertType(elementType) + .cast() + .getPointerTo(); auto nullPtr = rewriter.create(loc, convertedPtrType); auto one = createIndexConstant(rewriter, loc, 1); auto gep = rewriter.create(loc, convertedPtrType, @@ -1464,7 +1468,7 @@ .getResult(0); } - auto structElementType = lowering.convertType(elementType); + auto structElementType = typeConverter.convertType(elementType); auto elementPtrType = structElementType.cast().getPointerTo( type.getMemorySpace()); Value bitcastAllocated = rewriter.create( @@ -1484,7 +1488,7 @@ "unexpected number of strides"); // Create the MemRef descriptor. - auto structType = lowering.convertType(type); + auto structType = typeConverter.convertType(type); auto memRefDescriptor = MemRefDescriptor::undef(rewriter, loc, structType); // Field 1: Allocated pointer, used for malloc/free. memRefDescriptor.setAllocatedPtr(rewriter, loc, bitcastAllocated); @@ -1578,11 +1582,12 @@ } if (numResults != 0) { - if (!(packedResult = this->lowering.packFunctionResults(resultTypes))) + if (!(packedResult = + this->typeConverter.packFunctionResults(resultTypes))) return this->matchFailure(); } - auto promoted = this->lowering.promoteMemRefDescriptors( + auto promoted = this->typeConverter.promoteMemRefDescriptors( op->getLoc(), /*opOperands=*/op->getOperands(), operands, rewriter); auto newOp = rewriter.create(op->getLoc(), packedResult, promoted, op->getAttrs()); @@ -1601,7 +1606,7 @@ SmallVector results; results.reserve(numResults); for (unsigned i = 0; i < numResults; ++i) { - auto type = this->lowering.convertType(op->getResult(i).getType()); + auto type = this->typeConverter.convertType(op->getResult(i).getType()); results.push_back(rewriter.create( op->getLoc(), type, newOp.getOperation()->getResult(0), rewriter.getI64ArrayAttr(i))); @@ -1749,7 +1754,7 @@ auto srcType = memRefCastOp.getOperand().getType(); auto dstType = memRefCastOp.getType(); - auto targetStructType = lowering.convertType(memRefCastOp.getType()); + auto targetStructType = typeConverter.convertType(memRefCastOp.getType()); auto loc = op->getLoc(); if (srcType.isa() && dstType.isa()) { @@ -1766,15 +1771,15 @@ auto srcMemRefType = srcType.cast(); int64_t rank = srcMemRefType.getRank(); // ptr = AllocaOp sizeof(MemRefDescriptor) - auto ptr = lowering.promoteOneMemRefDescriptor(loc, transformed.source(), - rewriter); + auto ptr = typeConverter.promoteOneMemRefDescriptor( + loc, transformed.source(), rewriter); // voidptr = BitCastOp srcType* to void* auto voidPtr = rewriter.create(loc, getVoidPtrType(), ptr) .getResult(); // rank = ConstantOp srcRank auto rankVal = rewriter.create( - loc, lowering.convertType(rewriter.getIntegerType(64)), + loc, typeConverter.convertType(rewriter.getIntegerType(64)), rewriter.getI64IntegerAttr(rank)); // undef = UndefOp UnrankedMemRefDescriptor memRefDesc = @@ -1967,7 +1972,7 @@ transformed.indices(), rewriter, getModule()); // Replace with llvm.prefetch. - auto llvmI32Type = lowering.convertType(rewriter.getIntegerType(32)); + auto llvmI32Type = typeConverter.convertType(rewriter.getIntegerType(32)); auto isWrite = rewriter.create( op->getLoc(), llvmI32Type, rewriter.getI32IntegerAttr(prefetchOp.isWrite())); @@ -1998,7 +2003,7 @@ auto indexCastOp = cast(op); auto targetType = - this->lowering.convertType(indexCastOp.getResult().getType()) + this->typeConverter.convertType(indexCastOp.getResult().getType()) .cast(); auto sourceType = transformed.in().getType().cast(); unsigned targetBits = targetType.getUnderlyingType()->getIntegerBitWidth(); @@ -2033,7 +2038,7 @@ CmpIOpOperandAdaptor transformed(operands); rewriter.replaceOpWithNewOp( - op, lowering.convertType(cmpiOp.getResult().getType()), + op, typeConverter.convertType(cmpiOp.getResult().getType()), rewriter.getI64IntegerAttr(static_cast( convertCmpPredicate(cmpiOp.getPredicate()))), transformed.lhs(), transformed.rhs()); @@ -2052,7 +2057,7 @@ CmpFOpOperandAdaptor transformed(operands); rewriter.replaceOpWithNewOp( - op, lowering.convertType(cmpfOp.getResult().getType()), + op, typeConverter.convertType(cmpfOp.getResult().getType()), rewriter.getI64IntegerAttr(static_cast( convertCmpPredicate(cmpfOp.getPredicate()))), transformed.lhs(), transformed.rhs()); @@ -2138,8 +2143,8 @@ // Otherwise, we need to pack the arguments into an LLVM struct type before // returning. - auto packedType = - lowering.packFunctionResults(llvm::to_vector<4>(op->getOperandTypes())); + auto packedType = typeConverter.packFunctionResults( + llvm::to_vector<4>(op->getOperandTypes())); Value packed = rewriter.create(op->getLoc(), packedType); for (unsigned i = 0; i < numArguments; ++i) { @@ -2177,10 +2182,10 @@ return matchFailure(); // First insert it into an undef vector so we can shuffle it. - auto vectorType = lowering.convertType(splatOp.getType()); + auto vectorType = typeConverter.convertType(splatOp.getType()); Value undef = rewriter.create(op->getLoc(), vectorType); auto zero = rewriter.create( - op->getLoc(), lowering.convertType(rewriter.getIntegerType(32)), + op->getLoc(), typeConverter.convertType(rewriter.getIntegerType(32)), rewriter.getZeroAttr(rewriter.getIntegerType(32))); auto v = rewriter.create( @@ -2213,7 +2218,7 @@ // First insert it into an undef vector so we can shuffle it. auto loc = op->getLoc(); - auto vectorTypeInfo = extractNDVectorTypeInfo(resultType, lowering); + auto vectorTypeInfo = extractNDVectorTypeInfo(resultType, typeConverter); auto llvmArrayTy = vectorTypeInfo.llvmArrayTy; auto llvmVectorTy = vectorTypeInfo.llvmVectorTy; if (!llvmArrayTy || !llvmVectorTy) @@ -2226,7 +2231,7 @@ // places within the returned descriptor. Value vdesc = rewriter.create(loc, llvmVectorTy); auto zero = rewriter.create( - loc, lowering.convertType(rewriter.getIntegerType(32)), + loc, typeConverter.convertType(rewriter.getIntegerType(32)), rewriter.getZeroAttr(rewriter.getIntegerType(32))); Value v = rewriter.create(loc, llvmVectorTy, vdesc, adaptor.input(), zero); @@ -2278,14 +2283,15 @@ auto sourceMemRefType = viewOp.source().getType().cast(); auto sourceElementTy = - lowering.convertType(sourceMemRefType.getElementType()) + typeConverter.convertType(sourceMemRefType.getElementType()) .dyn_cast_or_null(); auto viewMemRefType = viewOp.getType(); - auto targetElementTy = lowering.convertType(viewMemRefType.getElementType()) - .dyn_cast(); - auto targetDescTy = - lowering.convertType(viewMemRefType).dyn_cast_or_null(); + auto targetElementTy = + typeConverter.convertType(viewMemRefType.getElementType()) + .dyn_cast(); + auto targetDescTy = typeConverter.convertType(viewMemRefType) + .dyn_cast_or_null(); if (!sourceElementTy || !targetDescTy) return matchFailure(); @@ -2333,7 +2339,7 @@ strideValues.push_back(sourceMemRef.stride(rewriter, loc, i)); // Fill in missing dynamic sizes. - auto llvmIndexType = lowering.convertType(rewriter.getIndexType()); + auto llvmIndexType = typeConverter.convertType(rewriter.getIndexType()); if (dynamicSizes.empty()) { dynamicSizes.reserve(viewMemRefType.getRank()); auto shape = viewMemRefType.getShape(); @@ -2424,10 +2430,11 @@ ViewOpOperandAdaptor adaptor(operands); auto viewMemRefType = viewOp.getType(); - auto targetElementTy = lowering.convertType(viewMemRefType.getElementType()) - .dyn_cast(); + auto targetElementTy = + typeConverter.convertType(viewMemRefType.getElementType()) + .dyn_cast(); auto targetDescTy = - lowering.convertType(viewMemRefType).dyn_cast(); + typeConverter.convertType(viewMemRefType).dyn_cast(); if (!targetDescTy) return op->emitWarning("Target descriptor type not converted to LLVM"), matchFailure(); diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp @@ -36,8 +36,8 @@ template static LLVM::LLVMType getPtrToElementType(T containerType, - LLVMTypeConverter &lowering) { - return lowering.convertType(containerType.getElementType()) + LLVMTypeConverter &typeConverter) { + return typeConverter.convertType(containerType.getElementType()) .template cast() .getPointerTo(); } @@ -56,12 +56,13 @@ // Helper that picks the proper sequence for inserting. static Value insertOne(ConversionPatternRewriter &rewriter, - LLVMTypeConverter &lowering, Location loc, Value val1, - Value val2, Type llvmType, int64_t rank, int64_t pos) { + LLVMTypeConverter &typeConverter, Location loc, + Value val1, Value val2, Type llvmType, int64_t rank, + int64_t pos) { if (rank == 1) { auto idxType = rewriter.getIndexType(); auto constant = rewriter.create( - loc, lowering.convertType(idxType), + loc, typeConverter.convertType(idxType), rewriter.getIntegerAttr(idxType, pos)); return rewriter.create(loc, llvmType, val1, val2, constant); @@ -83,12 +84,12 @@ // Helper that picks the proper sequence for extracting. static Value extractOne(ConversionPatternRewriter &rewriter, - LLVMTypeConverter &lowering, Location loc, Value val, - Type llvmType, int64_t rank, int64_t pos) { + LLVMTypeConverter &typeConverter, Location loc, + Value val, Type llvmType, int64_t rank, int64_t pos) { if (rank == 1) { auto idxType = rewriter.getIndexType(); auto constant = rewriter.create( - loc, lowering.convertType(idxType), + loc, typeConverter.convertType(idxType), rewriter.getIntegerAttr(idxType, pos)); return rewriter.create(loc, llvmType, val, constant); @@ -137,7 +138,7 @@ ConversionPatternRewriter &rewriter) const override { auto broadcastOp = cast(op); VectorType dstVectorType = broadcastOp.getVectorType(); - if (lowering.convertType(dstVectorType) == nullptr) + if (typeConverter.convertType(dstVectorType) == nullptr) return matchFailure(); // Rewrite when the full vector type can be lowered (which // implies all 'reduced' types can be lowered too). @@ -203,12 +204,12 @@ Value duplicateOneRank(Value value, Location loc, VectorType srcVectorType, VectorType dstVectorType, int64_t rank, int64_t dim, ConversionPatternRewriter &rewriter) const { - Type llvmType = lowering.convertType(dstVectorType); + Type llvmType = typeConverter.convertType(dstVectorType); assert((llvmType != nullptr) && "unlowerable vector type"); if (rank == 1) { Value undef = rewriter.create(loc, llvmType); - Value expand = - insertOne(rewriter, lowering, loc, undef, value, llvmType, rank, 0); + Value expand = insertOne(rewriter, typeConverter, loc, undef, value, + llvmType, rank, 0); SmallVector zeroValues(dim, 0); return rewriter.create( loc, expand, undef, rewriter.getI32ArrayAttr(zeroValues)); @@ -217,8 +218,8 @@ reducedVectorTypeFront(dstVectorType), rewriter); Value result = rewriter.create(loc, llvmType); for (int64_t d = 0; d < dim; ++d) { - result = - insertOne(rewriter, lowering, loc, result, expand, llvmType, rank, d); + result = insertOne(rewriter, typeConverter, loc, result, expand, llvmType, + rank, d); } return result; } @@ -243,31 +244,32 @@ Value stretchOneRank(Value value, Location loc, VectorType srcVectorType, VectorType dstVectorType, int64_t rank, int64_t dim, ConversionPatternRewriter &rewriter) const { - Type llvmType = lowering.convertType(dstVectorType); + Type llvmType = typeConverter.convertType(dstVectorType); assert((llvmType != nullptr) && "unlowerable vector type"); Value result = rewriter.create(loc, llvmType); bool atStretch = dim != srcVectorType.getDimSize(0); if (rank == 1) { assert(atStretch); - Type redLlvmType = lowering.convertType(dstVectorType.getElementType()); + Type redLlvmType = + typeConverter.convertType(dstVectorType.getElementType()); Value one = - extractOne(rewriter, lowering, loc, value, redLlvmType, rank, 0); - Value expand = - insertOne(rewriter, lowering, loc, result, one, llvmType, rank, 0); + extractOne(rewriter, typeConverter, loc, value, redLlvmType, rank, 0); + Value expand = insertOne(rewriter, typeConverter, loc, result, one, + llvmType, rank, 0); SmallVector zeroValues(dim, 0); return rewriter.create( loc, expand, result, rewriter.getI32ArrayAttr(zeroValues)); } VectorType redSrcType = reducedVectorTypeFront(srcVectorType); VectorType redDstType = reducedVectorTypeFront(dstVectorType); - Type redLlvmType = lowering.convertType(redSrcType); + Type redLlvmType = typeConverter.convertType(redSrcType); for (int64_t d = 0; d < dim; ++d) { int64_t pos = atStretch ? 0 : d; - Value one = - extractOne(rewriter, lowering, loc, value, redLlvmType, rank, pos); + Value one = extractOne(rewriter, typeConverter, loc, value, redLlvmType, + rank, pos); Value expand = expandRanks(one, loc, redSrcType, redDstType, rewriter); - result = - insertOne(rewriter, lowering, loc, result, expand, llvmType, rank, d); + result = insertOne(rewriter, typeConverter, loc, result, expand, llvmType, + rank, d); } return result; } @@ -286,7 +288,7 @@ auto reductionOp = cast(op); auto kind = reductionOp.kind(); Type eltType = reductionOp.dest().getType(); - Type llvmType = lowering.convertType(eltType); + Type llvmType = typeConverter.convertType(eltType); if (eltType.isInteger(32) || eltType.isInteger(64)) { // Integer reductions: add/mul/min/max/and/or/xor. if (kind == "add") @@ -353,7 +355,7 @@ auto reductionOp = cast(op); auto kind = reductionOp.kind(); Type eltType = reductionOp.dest().getType(); - Type llvmType = lowering.convertType(eltType); + Type llvmType = typeConverter.convertType(eltType); if (kind == "add") { rewriter.replaceOpWithNewOp( op, llvmType, operands[1], operands[0]); @@ -383,7 +385,7 @@ auto v1Type = shuffleOp.getV1VectorType(); auto v2Type = shuffleOp.getV2VectorType(); auto vectorType = shuffleOp.getVectorType(); - Type llvmType = lowering.convertType(vectorType); + Type llvmType = typeConverter.convertType(vectorType); auto maskArrayAttr = shuffleOp.mask(); // Bail if result type cannot be lowered. @@ -415,10 +417,10 @@ extPos -= v1Dim; value = adaptor.v2(); } - Value extract = - extractOne(rewriter, lowering, loc, value, llvmType, rank, extPos); - insert = insertOne(rewriter, lowering, loc, insert, extract, llvmType, - rank, insPos++); + Value extract = extractOne(rewriter, typeConverter, loc, value, llvmType, + rank, extPos); + insert = insertOne(rewriter, typeConverter, loc, insert, extract, + llvmType, rank, insPos++); } rewriter.replaceOp(op, insert); return matchSuccess(); @@ -438,7 +440,7 @@ auto adaptor = vector::ExtractElementOpOperandAdaptor(operands); auto extractEltOp = cast(op); auto vectorType = extractEltOp.getVectorType(); - auto llvmType = lowering.convertType(vectorType.getElementType()); + auto llvmType = typeConverter.convertType(vectorType.getElementType()); // Bail if result type cannot be lowered. if (!llvmType) @@ -465,7 +467,7 @@ auto extractOp = cast(op); auto vectorType = extractOp.getVectorType(); auto resultType = extractOp.getResult().getType(); - auto llvmResultType = lowering.convertType(resultType); + auto llvmResultType = typeConverter.convertType(resultType); auto positionArrayAttr = extractOp.position(); // Bail if result type cannot be lowered. @@ -489,13 +491,13 @@ auto nMinusOnePositionAttrs = ArrayAttr::get(positionAttrs.drop_back(), context); extracted = rewriter.create( - loc, lowering.convertType(oneDVectorType), extracted, + loc, typeConverter.convertType(oneDVectorType), extracted, nMinusOnePositionAttrs); } // Remaining extraction of element from 1-D LLVM vector auto position = positionAttrs.back().cast(); - auto i64Type = LLVM::LLVMType::getInt64Ty(lowering.getDialect()); + auto i64Type = LLVM::LLVMType::getInt64Ty(typeConverter.getDialect()); auto constant = rewriter.create(loc, i64Type, position); extracted = rewriter.create(loc, extracted, constant); @@ -553,7 +555,7 @@ auto adaptor = vector::InsertElementOpOperandAdaptor(operands); auto insertEltOp = cast(op); auto vectorType = insertEltOp.getDestVectorType(); - auto llvmType = lowering.convertType(vectorType); + auto llvmType = typeConverter.convertType(vectorType); // Bail if result type cannot be lowered. if (!llvmType) @@ -580,7 +582,7 @@ auto insertOp = cast(op); auto sourceType = insertOp.getSourceType(); auto destVectorType = insertOp.getDestVectorType(); - auto llvmResultType = lowering.convertType(destVectorType); + auto llvmResultType = typeConverter.convertType(destVectorType); auto positionArrayAttr = insertOp.position(); // Bail if result type cannot be lowered. @@ -607,16 +609,16 @@ auto nMinusOnePositionAttrs = ArrayAttr::get(positionAttrs.drop_back(), context); extracted = rewriter.create( - loc, lowering.convertType(oneDVectorType), extracted, + loc, typeConverter.convertType(oneDVectorType), extracted, nMinusOnePositionAttrs); } // Insertion of an element into a 1-D LLVM vector. - auto i64Type = LLVM::LLVMType::getInt64Ty(lowering.getDialect()); + auto i64Type = LLVM::LLVMType::getInt64Ty(typeConverter.getDialect()); auto constant = rewriter.create(loc, i64Type, position); Value inserted = rewriter.create( - loc, lowering.convertType(oneDVectorType), extracted, adaptor.source(), - constant); + loc, typeConverter.convertType(oneDVectorType), extracted, + adaptor.source(), constant); // Potential insertion of resulting 1-D vector into array. if (positionAttrs.size() > 1) { @@ -830,7 +832,7 @@ auto vRHS = adaptor.rhs().getType().cast(); auto rankLHS = vLHS.getUnderlyingType()->getVectorNumElements(); auto rankRHS = vRHS.getUnderlyingType()->getVectorNumElements(); - auto llvmArrayOfVectType = lowering.convertType( + auto llvmArrayOfVectType = typeConverter.convertType( cast(op).getResult().getType()); Value desc = rewriter.create(loc, llvmArrayOfVectType); Value a = adaptor.lhs(), b = adaptor.rhs(); @@ -893,7 +895,7 @@ return matchFailure(); MemRefDescriptor sourceMemRef(operands[0]); - auto llvmTargetDescriptorTy = lowering.convertType(targetMemRefType) + auto llvmTargetDescriptorTy = typeConverter.convertType(targetMemRefType) .dyn_cast_or_null(); if (!llvmTargetDescriptorTy || !llvmTargetDescriptorTy.isStructTy()) return matchFailure(); @@ -916,7 +918,7 @@ if (failed(successStrides) || !isContiguous) return matchFailure(); - auto int64Ty = LLVM::LLVMType::getInt64Ty(lowering.getDialect()); + auto int64Ty = LLVM::LLVMType::getInt64Ty(typeConverter.getDialect()); // Create descriptor. auto desc = MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy); @@ -979,7 +981,7 @@ auto adaptor = vector::PrintOpOperandAdaptor(operands); Type printType = printOp.getPrintType(); - if (lowering.convertType(printType) == nullptr) + if (typeConverter.convertType(printType) == nullptr) return matchFailure(); // Make sure element type has runtime support (currently just Float/Double). @@ -1021,10 +1023,10 @@ for (int64_t d = 0; d < dim; ++d) { auto reducedType = rank > 1 ? reducedVectorTypeFront(vectorType) : nullptr; - auto llvmType = lowering.convertType( + auto llvmType = typeConverter.convertType( rank > 1 ? reducedType : vectorType.getElementType()); Value nestedVal = - extractOne(rewriter, lowering, loc, value, llvmType, rank, d); + extractOne(rewriter, typeConverter, loc, value, llvmType, rank, d); emitRanks(rewriter, op, nestedVal, reducedType, printer, rank - 1); if (d != dim - 1) emitCall(rewriter, loc, printComma); @@ -1055,36 +1057,36 @@ // Helpers for method names. Operation *getPrintI32(Operation *op) const { - LLVM::LLVMDialect *dialect = lowering.getDialect(); + LLVM::LLVMDialect *dialect = typeConverter.getDialect(); return getPrint(op, dialect, "print_i32", LLVM::LLVMType::getInt32Ty(dialect)); } Operation *getPrintI64(Operation *op) const { - LLVM::LLVMDialect *dialect = lowering.getDialect(); + LLVM::LLVMDialect *dialect = typeConverter.getDialect(); return getPrint(op, dialect, "print_i64", LLVM::LLVMType::getInt64Ty(dialect)); } Operation *getPrintFloat(Operation *op) const { - LLVM::LLVMDialect *dialect = lowering.getDialect(); + LLVM::LLVMDialect *dialect = typeConverter.getDialect(); return getPrint(op, dialect, "print_f32", LLVM::LLVMType::getFloatTy(dialect)); } Operation *getPrintDouble(Operation *op) const { - LLVM::LLVMDialect *dialect = lowering.getDialect(); + LLVM::LLVMDialect *dialect = typeConverter.getDialect(); return getPrint(op, dialect, "print_f64", LLVM::LLVMType::getDoubleTy(dialect)); } Operation *getPrintOpen(Operation *op) const { - return getPrint(op, lowering.getDialect(), "print_open", {}); + return getPrint(op, typeConverter.getDialect(), "print_open", {}); } Operation *getPrintClose(Operation *op) const { - return getPrint(op, lowering.getDialect(), "print_close", {}); + return getPrint(op, typeConverter.getDialect(), "print_close", {}); } Operation *getPrintComma(Operation *op) const { - return getPrint(op, lowering.getDialect(), "print_comma", {}); + return getPrint(op, typeConverter.getDialect(), "print_comma", {}); } Operation *getPrintNewline(Operation *op) const { - return getPrint(op, lowering.getDialect(), "print_newline", {}); + return getPrint(op, typeConverter.getDialect(), "print_newline", {}); } };