diff --git a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h --- a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h +++ b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h @@ -71,7 +71,7 @@ /// Convert a function type. The arguments and results are converted one by /// one and results are packed into a wrapped LLVM IR structure type. `result` /// is populated with argument mapping. - LLVM::LLVMType convertFunctionSignature(FunctionType type, bool isVariadic, + LLVM::LLVMType convertFunctionSignature(FunctionType funcTy, bool isVariadic, SignatureConversion &result); /// Convert a non-empty list of types to be returned from a function into a @@ -485,6 +485,8 @@ /// Returns the LLVM dialect. LLVM::LLVMDialect &getDialect() const; + LLVMTypeConverter *getTypeConverter() const; + /// Gets the MLIR type wrapping the LLVM integer type whose bit width is /// defined by the used type converter. LLVM::LLVMType getIndexType() const; @@ -556,10 +558,6 @@ Value allocatedPtr, Value alignedPtr, ArrayRef sizes, ArrayRef strides, ConversionPatternRewriter &rewriter) const; - -protected: - /// Reference to the type converter, with potential extensions. - LLVMTypeConverter &typeConverter; }; /// Utility class for operation conversions targeting the LLVM dialect that @@ -644,7 +642,7 @@ matchAndRewrite(SourceOp op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { return LLVM::detail::oneToOneRewrite(op, TargetOp::getOperationName(), - operands, this->typeConverter, + operands, *this->getTypeConverter(), rewriter); } }; @@ -666,9 +664,9 @@ static_assert(std::is_base_of, SourceOp>::value, "expected same operands and result type"); - return LLVM::detail::vectorOneToOneRewrite(op, TargetOp::getOperationName(), - operands, this->typeConverter, - rewriter); + return LLVM::detail::vectorOneToOneRewrite( + op, TargetOp::getOperationName(), operands, *this->getTypeConverter(), + rewriter); } }; diff --git a/mlir/lib/Conversion/AVX512ToLLVM/ConvertAVX512ToLLVM.cpp b/mlir/lib/Conversion/AVX512ToLLVM/ConvertAVX512ToLLVM.cpp --- a/mlir/lib/Conversion/AVX512ToLLVM/ConvertAVX512ToLLVM.cpp +++ b/mlir/lib/Conversion/AVX512ToLLVM/ConvertAVX512ToLLVM.cpp @@ -86,7 +86,7 @@ return failure(); return matchAndRewriteOneToOne( - *this, this->typeConverter, op, operands, rewriter); + *this, *getTypeConverter(), op, operands, rewriter); } }; @@ -103,7 +103,7 @@ return failure(); return matchAndRewriteOneToOne( - *this, this->typeConverter, op, operands, rewriter); + *this, *getTypeConverter(), op, operands, rewriter); } }; @@ -120,7 +120,7 @@ return failure(); return matchAndRewriteOneToOne( - *this, this->typeConverter, op, operands, rewriter); + *this, *getTypeConverter(), op, operands, rewriter); } }; @@ -137,7 +137,7 @@ return failure(); return matchAndRewriteOneToOne( - *this, this->typeConverter, op, operands, rewriter); + *this, *getTypeConverter(), op, operands, rewriter); } }; } // namespace diff --git a/mlir/lib/Conversion/GPUCommon/ConvertLaunchFuncToRuntimeCalls.cpp b/mlir/lib/Conversion/GPUCommon/ConvertLaunchFuncToRuntimeCalls.cpp --- a/mlir/lib/Conversion/GPUCommon/ConvertLaunchFuncToRuntimeCalls.cpp +++ b/mlir/lib/Conversion/GPUCommon/ConvertLaunchFuncToRuntimeCalls.cpp @@ -72,7 +72,7 @@ : ConvertOpToLLVMPattern(typeConverter) {} protected: - MLIRContext *context = &this->typeConverter.getContext(); + MLIRContext *context = &this->getTypeConverter()->getContext(); LLVM::LLVMType llvmVoidType = LLVM::LLVMType::getVoidTy(context); LLVM::LLVMType llvmPointerType = LLVM::LLVMType::getInt8PtrTy(context); @@ -81,7 +81,7 @@ LLVM::LLVMType llvmInt32Type = LLVM::LLVMType::getInt32Ty(context); LLVM::LLVMType llvmInt64Type = LLVM::LLVMType::getInt64Ty(context); LLVM::LLVMType llvmIntPtrType = LLVM::LLVMType::getIntNTy( - context, this->typeConverter.getPointerBitwidth(0)); + context, this->getTypeConverter()->getPointerBitwidth(0)); FunctionCallBuilder moduleLoadCallBuilder = { "mgpuModuleLoad", @@ -333,8 +333,8 @@ auto elementType = memRefType.cast().getElementType(); auto elementSize = getSizeInBytes(loc, elementType, rewriter); - auto arguments = - typeConverter.promoteOperands(loc, op->getOperands(), operands, rewriter); + auto arguments = getTypeConverter()->promoteOperands(loc, op->getOperands(), + operands, rewriter); arguments.push_back(elementSize); hostRegisterCallBuilder.create(loc, rewriter, arguments); @@ -486,7 +486,7 @@ OpBuilder &builder) const { auto loc = launchOp.getLoc(); auto numKernelOperands = launchOp.getNumKernelOperands(); - auto arguments = typeConverter.promoteOperands( + auto arguments = getTypeConverter()->promoteOperands( loc, launchOp.getOperands().take_back(numKernelOperands), operands.take_back(numKernelOperands), builder); auto numArguments = arguments.size(); diff --git a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h --- a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h +++ b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h @@ -41,7 +41,7 @@ uint64_t numElements = type.getNumElements(); - auto elementType = typeConverter.convertType(type.getElementType()) + auto elementType = typeConverter->convertType(type.getElementType()) .template cast(); auto arrayType = LLVM::LLVMType::getArrayTy(elementType, numElements); std::string name = std::string( @@ -54,14 +54,14 @@ } // Rewrite the original GPU function to an LLVM function. - auto funcType = typeConverter.convertType(gpuFuncOp.getType()) + auto funcType = typeConverter->convertType(gpuFuncOp.getType()) .template cast() .getPointerElementTy(); // Remap proper input types. TypeConverter::SignatureConversion signatureConversion( gpuFuncOp.front().getNumArguments()); - typeConverter.convertFunctionSignature( + getTypeConverter()->convertFunctionSignature( gpuFuncOp.getType(), /*isVariadic=*/false, signatureConversion); // Create the new function operation. Only copy those attributes that are @@ -110,7 +110,7 @@ Value attribution = gpuFuncOp.getWorkgroupAttributions()[en.index()]; auto type = attribution.getType().cast(); auto descr = MemRefDescriptor::fromStaticShape( - rewriter, loc, typeConverter, type, memory); + rewriter, loc, *getTypeConverter(), type, memory); signatureConversion.remapInput(numProperArguments + en.index(), descr); } @@ -127,7 +127,7 @@ // Explicitly drop memory space when lowering private memory // attributions since NVVM models it as `alloca`s in the default // memory space and does not support `alloca`s with addrspace(5). - auto ptrType = typeConverter.convertType(type.getElementType()) + auto ptrType = typeConverter->convertType(type.getElementType()) .template cast() .getPointerTo(AllocaAddrSpace); Value numElements = rewriter.create( @@ -136,7 +136,7 @@ Value allocated = rewriter.create( gpuFuncOp.getLoc(), ptrType, numElements, /*alignment=*/0); auto descr = MemRefDescriptor::fromStaticShape( - rewriter, loc, typeConverter, type, allocated); + rewriter, loc, *getTypeConverter(), type, allocated); signatureConversion.remapInput( numProperArguments + numWorkgroupAttributions + en.index(), descr); } @@ -145,8 +145,8 @@ // Move the region to the new function, update the entry block signature. rewriter.inlineRegionBefore(gpuFuncOp.getBody(), llvmFuncOp.getBody(), llvmFuncOp.end()); - if (failed(rewriter.convertRegionTypes(&llvmFuncOp.getBody(), typeConverter, - &signatureConversion))) + if (failed(rewriter.convertRegionTypes( + &llvmFuncOp.getBody(), *typeConverter, &signatureConversion))) return failure(); rewriter.eraseOp(gpuFuncOp); diff --git a/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp b/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp --- a/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp +++ b/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp @@ -135,8 +135,8 @@ matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto rangeOp = cast(op); - auto rangeDescriptorTy = - convertRangeType(rangeOp.getType().cast(), typeConverter); + auto rangeDescriptorTy = convertRangeType( + rangeOp.getType().cast(), *getTypeConverter()); edsc::ScopedContext context(rewriter, op->getLoc()); @@ -181,7 +181,7 @@ edsc::ScopedContext context(rewriter, op->getLoc()); ReshapeOpAdaptor adaptor(operands); BaseViewConversionHelper baseDesc(adaptor.src()); - BaseViewConversionHelper desc(typeConverter.convertType(dstType)); + BaseViewConversionHelper desc(typeConverter->convertType(dstType)); desc.setAllocatedPtr(baseDesc.allocatedPtr()); desc.setAlignedPtr(baseDesc.alignedPtr()); desc.setOffset(baseDesc.offset()); @@ -214,11 +214,11 @@ auto sliceOp = cast(op); auto memRefType = sliceOp.getBaseViewType(); - auto int64Ty = typeConverter.convertType(rewriter.getIntegerType(64)) + auto int64Ty = typeConverter->convertType(rewriter.getIntegerType(64)) .cast(); BaseViewConversionHelper desc( - typeConverter.convertType(sliceOp.getShapedType())); + typeConverter->convertType(sliceOp.getShapedType())); // TODO: extract sizes and emit asserts. SmallVector strides(memRefType.getRank()); diff --git a/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp b/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp --- a/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp +++ b/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp @@ -35,7 +35,7 @@ curOp.getAttrs()); rewriter.inlineRegionBefore(curOp.region(), newOp.region(), newOp.region().end()); - if (failed(rewriter.convertRegionTypes(&newOp.region(), typeConverter))) + if (failed(rewriter.convertRegionTypes(&newOp.region(), *typeConverter))) return failure(); rewriter.eraseOp(op); diff --git a/mlir/lib/Conversion/SPIRVToLLVM/ConvertLaunchFuncToLLVMCalls.cpp b/mlir/lib/Conversion/SPIRVToLLVM/ConvertLaunchFuncToLLVMCalls.cpp --- a/mlir/lib/Conversion/SPIRVToLLVM/ConvertLaunchFuncToLLVMCalls.cpp +++ b/mlir/lib/Conversion/SPIRVToLLVM/ConvertLaunchFuncToLLVMCalls.cpp @@ -224,7 +224,7 @@ spirv::GlobalVariableOp spirvGlobal = globalVariableMap[operand.index()]; auto pointeeType = spirvGlobal.type().cast().getPointeeType(); - auto dstGlobalType = typeConverter.convertType(pointeeType); + auto dstGlobalType = typeConverter->convertType(pointeeType); if (!dstGlobalType) return failure(); std::string name = diff --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp --- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp +++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp @@ -446,8 +446,7 @@ MLIRContext *context, LLVMTypeConverter &typeConverter, PatternBenefit benefit) - : ConversionPattern(rootOpName, benefit, typeConverter, context), - typeConverter(typeConverter) {} + : ConversionPattern(rootOpName, benefit, typeConverter, context) {} //===----------------------------------------------------------------------===// // StructBuilder implementation @@ -1013,27 +1012,32 @@ builder.create(loc, stride, strideStoreGep); } +LLVMTypeConverter *ConvertToLLVMPattern::getTypeConverter() const { + return static_cast( + ConversionPattern::getTypeConverter()); +} + LLVM::LLVMDialect &ConvertToLLVMPattern::getDialect() const { - return *typeConverter.getDialect(); + return *getTypeConverter()->getDialect(); } LLVM::LLVMType ConvertToLLVMPattern::getIndexType() const { - return typeConverter.getIndexType(); + return getTypeConverter()->getIndexType(); } LLVM::LLVMType ConvertToLLVMPattern::getIntPtrType(unsigned addressSpace) const { return LLVM::LLVMType::getIntNTy( - &typeConverter.getContext(), - typeConverter.getPointerBitwidth(addressSpace)); + &getTypeConverter()->getContext(), + getTypeConverter()->getPointerBitwidth(addressSpace)); } LLVM::LLVMType ConvertToLLVMPattern::getVoidType() const { - return LLVM::LLVMType::getVoidTy(&typeConverter.getContext()); + return LLVM::LLVMType::getVoidTy(&getTypeConverter()->getContext()); } LLVM::LLVMType ConvertToLLVMPattern::getVoidPtrType() const { - return LLVM::LLVMType::getInt8PtrTy(&typeConverter.getContext()); + return LLVM::LLVMType::getInt8PtrTy(&getTypeConverter()->getContext()); } Value ConvertToLLVMPattern::createIndexConstant( @@ -1086,7 +1090,7 @@ // Check if the MemRefType `type` is supported by the lowering. We currently // only support memrefs with identity maps. bool ConvertToLLVMPattern::isSupportedMemRefType(MemRefType type) const { - if (!typeConverter.convertType(type.getElementType())) + if (!typeConverter->convertType(type.getElementType())) return false; return type.getAffineMaps().empty() || llvm::all_of(type.getAffineMaps(), @@ -1095,7 +1099,7 @@ Type ConvertToLLVMPattern::getElementPtrType(MemRefType type) const { auto elementType = type.getElementType(); - auto structElementType = unwrap(typeConverter.convertType(elementType)); + auto structElementType = unwrap(typeConverter->convertType(elementType)); return structElementType.getPointerTo(type.getMemorySpace()); } @@ -1155,7 +1159,7 @@ // %1 = ptrtoint %elementType* %0 to %indexType // which is a common pattern of getting the size of a type in bytes. auto convertedPtrType = - typeConverter.convertType(type).cast().getPointerTo(); + typeConverter->convertType(type).cast().getPointerTo(); auto nullPtr = rewriter.create(loc, convertedPtrType); auto gep = rewriter.create( loc, convertedPtrType, @@ -1179,7 +1183,7 @@ Location loc, MemRefType memRefType, Value allocatedPtr, Value alignedPtr, ArrayRef sizes, ArrayRef strides, ConversionPatternRewriter &rewriter) const { - auto structType = typeConverter.convertType(memRefType); + auto structType = typeConverter->convertType(memRefType); auto memRefDescriptor = MemRefDescriptor::undef(rewriter, loc, structType); // Field 1: Allocated pointer, used for malloc/free. @@ -1347,7 +1351,7 @@ // LLVMTypeConverter provided to this legalization pattern. auto varargsAttr = funcOp.getAttrOfType("std.varargs"); TypeConverter::SignatureConversion result(funcOp.getNumArguments()); - auto llvmType = typeConverter.convertFunctionSignature( + auto llvmType = getTypeConverter()->convertFunctionSignature( funcOp.getType(), varargsAttr && varargsAttr.getValue(), result); if (!llvmType) return nullptr; @@ -1379,7 +1383,7 @@ attributes); rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(), newFuncOp.end()); - if (failed(rewriter.convertRegionTypes(&newFuncOp.getBody(), typeConverter, + if (failed(rewriter.convertRegionTypes(&newFuncOp.getBody(), *typeConverter, &result))) return nullptr; @@ -1402,14 +1406,14 @@ if (!newFuncOp) return failure(); - if (typeConverter.getOptions().emitCWrappers || + if (getTypeConverter()->getOptions().emitCWrappers || funcOp.getAttrOfType(kEmitIfaceAttrName)) { if (newFuncOp.isExternal()) - wrapExternalFunction(rewriter, funcOp.getLoc(), typeConverter, funcOp, - newFuncOp); + wrapExternalFunction(rewriter, funcOp.getLoc(), *getTypeConverter(), + funcOp, newFuncOp); else - wrapForExternalCallers(rewriter, funcOp.getLoc(), typeConverter, funcOp, - newFuncOp); + wrapForExternalCallers(rewriter, funcOp.getLoc(), *getTypeConverter(), + funcOp, newFuncOp); } rewriter.eraseOp(funcOp); @@ -1472,7 +1476,7 @@ rewriter.replaceUsesOfBlockArgument(arg, placeholder); Value desc = MemRefDescriptor::fromStaticShape( - rewriter, loc, typeConverter, memrefTy, arg); + rewriter, loc, *getTypeConverter(), memrefTy, arg); rewriter.replaceOp(placeholder, {desc}); } @@ -1757,7 +1761,7 @@ // Pack real and imaginary part in a complex number struct. auto loc = op.getLoc(); - auto structType = typeConverter.convertType(complexOp.getType()); + auto structType = typeConverter->convertType(complexOp.getType()); auto complexStruct = ComplexStructBuilder::undef(rewriter, loc, structType); complexStruct.setReal(rewriter, loc, transformed.real()); complexStruct.setImaginary(rewriter, loc, transformed.imaginary()); @@ -1836,7 +1840,7 @@ unpackBinaryComplexOperands(op, operands, rewriter); // Initialize complex number struct for result. - auto structType = this->typeConverter.convertType(op.getType()); + auto structType = typeConverter->convertType(op.getType()); auto result = ComplexStructBuilder::undef(rewriter, loc, structType); // Emit IR to add complex numbers. @@ -1863,7 +1867,7 @@ unpackBinaryComplexOperands(op, operands, rewriter); // Initialize complex number struct for result. - auto structType = this->typeConverter.convertType(op.getType()); + auto structType = typeConverter->convertType(op.getType()); auto result = ComplexStructBuilder::undef(rewriter, loc, structType); // Emit IR to substract complex numbers. @@ -1887,7 +1891,7 @@ ConversionPatternRewriter &rewriter) const override { // If constant refers to a function, convert it to "addressof". if (auto symbolRef = op.getValue().dyn_cast()) { - auto type = typeConverter.convertType(op.getResult().getType()) + auto type = typeConverter->convertType(op.getResult().getType()) .dyn_cast_or_null(); if (!type) return rewriter.notifyMatchFailure(op, "failed to convert result type"); @@ -1905,9 +1909,9 @@ return rewriter.notifyMatchFailure( op, "referring to a symbol outside of the current module"); - return LLVM::detail::oneToOneRewrite(op, - LLVM::ConstantOp::getOperationName(), - operands, typeConverter, rewriter); + return LLVM::detail::oneToOneRewrite( + op, LLVM::ConstantOp::getOperationName(), operands, *getTypeConverter(), + rewriter); } }; @@ -1916,7 +1920,6 @@ using ConvertToLLVMPattern::createIndexConstant; using ConvertToLLVMPattern::getIndexType; using ConvertToLLVMPattern::getVoidPtrType; - using ConvertToLLVMPattern::typeConverter; explicit AllocLikeOpLowering(StringRef opName, LLVMTypeConverter &converter) : ConvertToLLVMPattern(opName, &converter.getContext(), converter) {} @@ -2288,11 +2291,11 @@ if (numResults != 0) { if (!(packedResult = - this->typeConverter.packFunctionResults(resultTypes))) + this->getTypeConverter()->packFunctionResults(resultTypes))) return failure(); } - auto promoted = this->typeConverter.promoteOperands( + auto promoted = this->getTypeConverter()->promoteOperands( callOp.getLoc(), /*opOperands=*/callOp->getOperands(), operands, rewriter); auto newOp = rewriter.create( @@ -2309,23 +2312,23 @@ results.reserve(numResults); for (unsigned i = 0; i < numResults; ++i) { auto type = - this->typeConverter.convertType(callOp.getResult(i).getType()); + this->typeConverter->convertType(callOp.getResult(i).getType()); results.push_back(rewriter.create( callOp.getLoc(), type, newOp->getResult(0), rewriter.getI64ArrayAttr(i))); } } - if (this->typeConverter.getOptions().useBarePtrCallConv) { + if (this->getTypeConverter()->getOptions().useBarePtrCallConv) { // For the bare-ptr calling convention, promote memref results to // descriptors. assert(results.size() == resultTypes.size() && "The number of arguments and types doesn't match"); - this->typeConverter.promoteBarePtrsToDescriptors( + this->getTypeConverter()->promoteBarePtrsToDescriptors( rewriter, callOp.getLoc(), resultTypes, results); } else if (failed(copyUnrankedDescriptors(rewriter, callOp.getLoc(), - this->typeConverter, resultTypes, - results, + *this->getTypeConverter(), + resultTypes, results, /*toDynamic=*/false))) { return failure(); } @@ -2410,7 +2413,8 @@ if (!isSupportedMemRefType(type)) return failure(); - LLVM::LLVMType arrayTy = convertGlobalMemrefTypeToLLVM(type, typeConverter); + LLVM::LLVMType arrayTy = + convertGlobalMemrefTypeToLLVM(type, *getTypeConverter()); LLVM::Linkage linkage = global.isPublic() ? LLVM::Linkage::External : LLVM::Linkage::Private; @@ -2449,14 +2453,15 @@ MemRefType type = getGlobalOp.result().getType().cast(); unsigned memSpace = type.getMemorySpace(); - LLVM::LLVMType arrayTy = convertGlobalMemrefTypeToLLVM(type, typeConverter); + LLVM::LLVMType arrayTy = + convertGlobalMemrefTypeToLLVM(type, *getTypeConverter()); auto addressOf = rewriter.create( loc, arrayTy.getPointerTo(memSpace), getGlobalOp.name()); // Get the address of the first element in the array by creating a GEP with // the address of the GV as the base, and (rank + 1) number of 0 indices. LLVM::LLVMType elementType = - unwrap(typeConverter.convertType(type.getElementType())); + unwrap(typeConverter->convertType(type.getElementType())); LLVM::LLVMType elementPtrType = elementType.getPointerTo(memSpace); SmallVector operands = {addressOf}; @@ -2517,7 +2522,7 @@ return failure(); return handleMultidimensionalVectors( - op.getOperation(), operands, typeConverter, + op.getOperation(), operands, *getTypeConverter(), [&](LLVM::LLVMType llvmVectorTy, ValueRange operands) { auto splatAttr = SplatElementsAttr::get( mlir::VectorType::get({llvmVectorTy.getVectorNumElements()}, @@ -2546,8 +2551,8 @@ // a sanity check that the underlying structs are the same. Once op // semantics are relaxed we can revisit. if (srcType.isa() && dstType.isa()) - return success(typeConverter.convertType(srcType) == - typeConverter.convertType(dstType)); + return success(typeConverter->convertType(srcType) == + typeConverter->convertType(dstType)); // At least one of the operands is unranked type assert(srcType.isa() || @@ -2566,7 +2571,7 @@ auto srcType = memRefCastOp.getOperand().getType(); auto dstType = memRefCastOp.getType(); - auto targetStructType = typeConverter.convertType(memRefCastOp.getType()); + auto targetStructType = typeConverter->convertType(memRefCastOp.getType()); auto loc = memRefCastOp.getLoc(); // For ranked/ranked case, just keep the original descriptor. @@ -2581,7 +2586,7 @@ auto srcMemRefType = srcType.cast(); int64_t rank = srcMemRefType.getRank(); // ptr = AllocaOp sizeof(MemRefDescriptor) - auto ptr = typeConverter.promoteOneMemRefDescriptor( + auto ptr = getTypeConverter()->promoteOneMemRefDescriptor( loc, transformed.source(), rewriter); // voidptr = BitCastOp srcType* to void* auto voidPtr = @@ -2589,7 +2594,7 @@ .getResult(); // rank = ConstantOp srcRank auto rankVal = rewriter.create( - loc, typeConverter.convertType(rewriter.getIntegerType(64)), + loc, typeConverter->convertType(rewriter.getIntegerType(64)), rewriter.getI64IntegerAttr(rank)); // undef = UndefOp UnrankedMemRefDescriptor memRefDesc = @@ -2693,7 +2698,7 @@ Value *descriptor) const { MemRefType targetMemRefType = castOp.getResult().getType().cast(); - auto llvmTargetDescriptorTy = typeConverter.convertType(targetMemRefType) + auto llvmTargetDescriptorTy = typeConverter->convertType(targetMemRefType) .dyn_cast_or_null(); if (!llvmTargetDescriptorTy || !llvmTargetDescriptorTy.isStructTy()) return failure(); @@ -2704,8 +2709,9 @@ // Set allocated and aligned pointers. Value allocatedPtr, alignedPtr; - extractPointersAndOffset(loc, rewriter, typeConverter, castOp.source(), - adaptor.source(), &allocatedPtr, &alignedPtr); + extractPointersAndOffset(loc, rewriter, *getTypeConverter(), + castOp.source(), adaptor.source(), &allocatedPtr, + &alignedPtr); desc.setAllocatedPtr(rewriter, loc, allocatedPtr); desc.setAlignedPtr(rewriter, loc, alignedPtr); @@ -2779,10 +2785,10 @@ // Create the unranked memref descriptor that holds the ranked one. The // inner descriptor is allocated on stack. auto targetDesc = UnrankedMemRefDescriptor::undef( - rewriter, loc, unwrap(typeConverter.convertType(targetType))); + rewriter, loc, unwrap(typeConverter->convertType(targetType))); targetDesc.setRank(rewriter, loc, resultRank); SmallVector sizes; - UnrankedMemRefDescriptor::computeSizes(rewriter, loc, typeConverter, + UnrankedMemRefDescriptor::computeSizes(rewriter, loc, *getTypeConverter(), targetDesc, sizes); Value underlyingDescPtr = rewriter.create( loc, getVoidPtrType(), sizes.front(), llvm::None); @@ -2790,37 +2796,38 @@ // Extract pointers and offset from the source memref. Value allocatedPtr, alignedPtr, offset; - extractPointersAndOffset(loc, rewriter, typeConverter, reshapeOp.source(), - adaptor.source(), &allocatedPtr, &alignedPtr, - &offset); + extractPointersAndOffset(loc, rewriter, *getTypeConverter(), + reshapeOp.source(), adaptor.source(), + &allocatedPtr, &alignedPtr, &offset); // Set pointers and offset. LLVM::LLVMType llvmElementType = - unwrap(typeConverter.convertType(elementType)); + unwrap(typeConverter->convertType(elementType)); LLVM::LLVMType elementPtrPtrType = llvmElementType.getPointerTo(addressSpace).getPointerTo(); UnrankedMemRefDescriptor::setAllocatedPtr(rewriter, loc, underlyingDescPtr, elementPtrPtrType, allocatedPtr); - UnrankedMemRefDescriptor::setAlignedPtr(rewriter, loc, typeConverter, + UnrankedMemRefDescriptor::setAlignedPtr(rewriter, loc, *getTypeConverter(), underlyingDescPtr, elementPtrPtrType, alignedPtr); - UnrankedMemRefDescriptor::setOffset(rewriter, loc, typeConverter, + UnrankedMemRefDescriptor::setOffset(rewriter, loc, *getTypeConverter(), underlyingDescPtr, elementPtrPtrType, offset); // Use the offset pointer as base for further addressing. Copy over the new // shape and compute strides. For this, we create a loop from rank-1 to 0. Value targetSizesBase = UnrankedMemRefDescriptor::sizeBasePtr( - rewriter, loc, typeConverter, underlyingDescPtr, elementPtrPtrType); + rewriter, loc, *getTypeConverter(), underlyingDescPtr, + elementPtrPtrType); Value targetStridesBase = UnrankedMemRefDescriptor::strideBasePtr( - rewriter, loc, typeConverter, targetSizesBase, resultRank); + rewriter, loc, *getTypeConverter(), targetSizesBase, resultRank); Value shapeOperandPtr = shapeDesc.alignedPtr(rewriter, loc); Value oneIndex = createIndexConstant(rewriter, loc, 1); Value resultRankMinusOne = rewriter.create(loc, resultRank, oneIndex); Block *initBlock = rewriter.getInsertionBlock(); - LLVM::LLVMType indexType = typeConverter.getIndexType(); + LLVM::LLVMType indexType = getTypeConverter()->getIndexType(); Block::iterator remainingOpsIt = std::next(rewriter.getInsertionPoint()); Block *condBlock = rewriter.createBlock(initBlock->getParent(), {}, @@ -2854,11 +2861,11 @@ Value sizeLoadGep = rewriter.create( loc, llvmIndexPtrType, shapeOperandPtr, ValueRange{indexArg}); Value size = rewriter.create(loc, sizeLoadGep); - UnrankedMemRefDescriptor::setSize(rewriter, loc, typeConverter, + UnrankedMemRefDescriptor::setSize(rewriter, loc, *getTypeConverter(), targetSizesBase, indexArg, size); // Write stride value and compute next one. - UnrankedMemRefDescriptor::setStride(rewriter, loc, typeConverter, + UnrankedMemRefDescriptor::setStride(rewriter, loc, *getTypeConverter(), targetStridesBase, indexArg, strideArg); Value nextStride = rewriter.create(loc, strideArg, size); @@ -2892,7 +2899,7 @@ ConversionPatternRewriter &rewriter) const override { LLVM::DialectCastOp::Adaptor transformed(operands); if (transformed.in().getType() != - typeConverter.convertType(castOp.getType())) { + typeConverter->convertType(castOp.getType())) { return failure(); } rewriter.replaceOp(castOp, transformed.in()); @@ -2942,15 +2949,16 @@ Value underlyingRankedDesc = unrankedDesc.memRefDescPtr(rewriter, loc); Value scalarMemRefDescPtr = rewriter.create( loc, - typeConverter.convertType(scalarMemRefType) + typeConverter->convertType(scalarMemRefType) .cast() .getPointerTo(addressSpace), underlyingRankedDesc); // Get pointer to offset field of memref descriptor. - Type indexPtrTy = typeConverter.getIndexType().getPointerTo(addressSpace); + Type indexPtrTy = + getTypeConverter()->getIndexType().getPointerTo(addressSpace); Value two = rewriter.create( - loc, typeConverter.convertType(rewriter.getI32Type()), + loc, typeConverter->convertType(rewriter.getI32Type()), rewriter.getI32IntegerAttr(2)); Value offsetPtr = rewriter.create( loc, indexPtrTy, scalarMemRefDescPtr, @@ -3082,7 +3090,7 @@ transformed.indices(), rewriter); // Replace with llvm.prefetch. - auto llvmI32Type = typeConverter.convertType(rewriter.getIntegerType(32)); + auto llvmI32Type = typeConverter->convertType(rewriter.getIntegerType(32)); auto isWrite = rewriter.create( loc, llvmI32Type, rewriter.getI32IntegerAttr(prefetchOp.isWrite())); auto localityHint = rewriter.create( @@ -3110,7 +3118,7 @@ IndexCastOpAdaptor transformed(operands); auto targetType = - this->typeConverter.convertType(indexCastOp.getResult().getType()) + typeConverter->convertType(indexCastOp.getResult().getType()) .cast(); auto sourceType = transformed.in().getType().cast(); unsigned targetBits = targetType.getIntegerBitWidth(); @@ -3144,7 +3152,7 @@ CmpIOpAdaptor transformed(operands); rewriter.replaceOpWithNewOp( - cmpiOp, typeConverter.convertType(cmpiOp.getResult().getType()), + cmpiOp, typeConverter->convertType(cmpiOp.getResult().getType()), rewriter.getI64IntegerAttr(static_cast( convertCmpPredicate(cmpiOp.getPredicate()))), transformed.lhs(), transformed.rhs()); @@ -3162,7 +3170,7 @@ CmpFOpAdaptor transformed(operands); rewriter.replaceOpWithNewOp( - cmpfOp, typeConverter.convertType(cmpfOp.getResult().getType()), + cmpfOp, typeConverter->convertType(cmpfOp.getResult().getType()), rewriter.getI64IntegerAttr(static_cast( convertCmpPredicate(cmpfOp.getPredicate()))), transformed.lhs(), transformed.rhs()); @@ -3248,7 +3256,7 @@ unsigned numArguments = op.getNumOperands(); SmallVector updatedOperands; - if (typeConverter.getOptions().useBarePtrCallConv) { + if (getTypeConverter()->getOptions().useBarePtrCallConv) { // For the bare-ptr calling convention, extract the aligned pointer to // be returned from the memref descriptor. for (auto it : llvm::zip(op->getOperands(), operands)) { @@ -3266,7 +3274,7 @@ } } else { updatedOperands = llvm::to_vector<4>(operands); - copyUnrankedDescriptors(rewriter, loc, typeConverter, + copyUnrankedDescriptors(rewriter, loc, *getTypeConverter(), op.getOperands().getTypes(), updatedOperands, /*toDynamic=*/true); } @@ -3285,7 +3293,7 @@ // Otherwise, we need to pack the arguments into an LLVM struct type before // returning. - auto packedType = typeConverter.packFunctionResults( + auto packedType = getTypeConverter()->packFunctionResults( llvm::to_vector<4>(op.getOperandTypes())); Value packed = rewriter.create(loc, packedType); @@ -3323,11 +3331,11 @@ return failure(); // First insert it into an undef vector so we can shuffle it. - auto vectorType = typeConverter.convertType(splatOp.getType()); + auto vectorType = typeConverter->convertType(splatOp.getType()); Value undef = rewriter.create(splatOp.getLoc(), vectorType); auto zero = rewriter.create( splatOp.getLoc(), - typeConverter.convertType(rewriter.getIntegerType(32)), + typeConverter->convertType(rewriter.getIntegerType(32)), rewriter.getZeroAttr(rewriter.getIntegerType(32))); auto v = rewriter.create( @@ -3360,7 +3368,8 @@ // First insert it into an undef vector so we can shuffle it. auto loc = splatOp.getLoc(); - auto vectorTypeInfo = extractNDVectorTypeInfo(resultType, typeConverter); + auto vectorTypeInfo = + extractNDVectorTypeInfo(resultType, *getTypeConverter()); auto llvmArrayTy = vectorTypeInfo.llvmArrayTy; auto llvmVectorTy = vectorTypeInfo.llvmVectorTy; if (!llvmArrayTy || !llvmVectorTy) @@ -3373,7 +3382,7 @@ // places within the returned descriptor. Value vdesc = rewriter.create(loc, llvmVectorTy); auto zero = rewriter.create( - loc, typeConverter.convertType(rewriter.getIntegerType(32)), + loc, typeConverter->convertType(rewriter.getIntegerType(32)), rewriter.getZeroAttr(rewriter.getIntegerType(32))); Value v = rewriter.create(loc, llvmVectorTy, vdesc, adaptor.input(), zero); @@ -3418,7 +3427,7 @@ auto sourceMemRefType = subViewOp.source().getType().cast(); auto sourceElementTy = - typeConverter.convertType(sourceMemRefType.getElementType()) + typeConverter->convertType(sourceMemRefType.getElementType()) .dyn_cast_or_null(); auto viewMemRefType = subViewOp.getType(); @@ -3429,9 +3438,9 @@ extractFromI64ArrayAttr(subViewOp.static_strides())) .cast(); auto targetElementTy = - typeConverter.convertType(viewMemRefType.getElementType()) + typeConverter->convertType(viewMemRefType.getElementType()) .dyn_cast(); - auto targetDescTy = typeConverter.convertType(viewMemRefType) + auto targetDescTy = typeConverter->convertType(viewMemRefType) .dyn_cast_or_null(); if (!sourceElementTy || !targetDescTy) return failure(); @@ -3477,7 +3486,7 @@ strideValues.push_back(sourceMemRef.stride(rewriter, loc, i)); // Offset. - auto llvmIndexType = typeConverter.convertType(rewriter.getIndexType()); + auto llvmIndexType = typeConverter->convertType(rewriter.getIndexType()); if (!ShapedType::isDynamicStrideOrOffset(offset)) { targetMemRef.setConstantOffset(rewriter, loc, offset); } else { @@ -3553,7 +3562,7 @@ return rewriter.replaceOp(transposeOp, {viewMemRef}), success(); auto targetMemRef = MemRefDescriptor::undef( - rewriter, loc, typeConverter.convertType(transposeOp.getShapedType())); + rewriter, loc, typeConverter->convertType(transposeOp.getShapedType())); // Copy the base and aligned pointers from the old descriptor to the new // one. @@ -3629,10 +3638,10 @@ auto viewMemRefType = viewOp.getType(); auto targetElementTy = - typeConverter.convertType(viewMemRefType.getElementType()) + typeConverter->convertType(viewMemRefType.getElementType()) .dyn_cast(); auto targetDescTy = - typeConverter.convertType(viewMemRefType).dyn_cast(); + typeConverter->convertType(viewMemRefType).dyn_cast(); if (!targetDescTy) return viewOp.emitWarning("Target descriptor type not converted to LLVM"), failure(); @@ -3825,7 +3834,7 @@ auto loc = atomicOp.getLoc(); GenericAtomicRMWOp::Adaptor adaptor(operands); LLVM::LLVMType valueType = - typeConverter.convertType(atomicOp.getResult().getType()) + typeConverter->convertType(atomicOp.getResult().getType()) .cast(); // Split the block into initial, loop, and ending parts. diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp @@ -309,7 +309,7 @@ auto matmulOp = cast(op); auto adaptor = vector::MatmulOpAdaptor(operands); rewriter.replaceOpWithNewOp( - op, typeConverter.convertType(matmulOp.res().getType()), adaptor.lhs(), + op, typeConverter->convertType(matmulOp.res().getType()), adaptor.lhs(), adaptor.rhs(), matmulOp.lhs_rows(), matmulOp.lhs_columns(), matmulOp.rhs_columns()); return success(); @@ -331,7 +331,7 @@ auto transOp = cast(op); auto adaptor = vector::FlatTransposeOpAdaptor(operands); rewriter.replaceOpWithNewOp( - transOp, typeConverter.convertType(transOp.res().getType()), + transOp, typeConverter->convertType(transOp.res().getType()), adaptor.matrix(), transOp.rows(), transOp.columns()); return success(); } @@ -354,10 +354,10 @@ // Resolve alignment. unsigned align; - if (failed(getMemRefAlignment(typeConverter, load, align))) + if (failed(getMemRefAlignment(*getTypeConverter(), load, align))) return failure(); - auto vtype = typeConverter.convertType(load.getResultVectorType()); + auto vtype = typeConverter->convertType(load.getResultVectorType()); Value ptr; if (failed(getBasePtr(rewriter, loc, adaptor.base(), load.getMemRefType(), vtype, ptr))) @@ -387,10 +387,10 @@ // Resolve alignment. unsigned align; - if (failed(getMemRefAlignment(typeConverter, store, align))) + if (failed(getMemRefAlignment(*getTypeConverter(), store, align))) return failure(); - auto vtype = typeConverter.convertType(store.getValueVectorType()); + auto vtype = typeConverter->convertType(store.getValueVectorType()); Value ptr; if (failed(getBasePtr(rewriter, loc, adaptor.base(), store.getMemRefType(), vtype, ptr))) @@ -420,7 +420,7 @@ // Resolve alignment. unsigned align; - if (failed(getMemRefAlignment(typeConverter, gather, align))) + if (failed(getMemRefAlignment(*getTypeConverter(), gather, align))) return failure(); // Get index ptrs. @@ -433,7 +433,7 @@ // Replace with the gather intrinsic. rewriter.replaceOpWithNewOp( - gather, typeConverter.convertType(vType), ptrs, adaptor.mask(), + gather, typeConverter->convertType(vType), ptrs, adaptor.mask(), adaptor.pass_thru(), rewriter.getI32IntegerAttr(align)); return success(); } @@ -456,7 +456,7 @@ // Resolve alignment. unsigned align; - if (failed(getMemRefAlignment(typeConverter, scatter, align))) + if (failed(getMemRefAlignment(*getTypeConverter(), scatter, align))) return failure(); // Get index ptrs. @@ -497,7 +497,7 @@ auto vType = expand.getResultVectorType(); rewriter.replaceOpWithNewOp( - op, typeConverter.convertType(vType), ptr, adaptor.mask(), + op, typeConverter->convertType(vType), ptr, adaptor.mask(), adaptor.pass_thru()); return success(); } @@ -545,7 +545,7 @@ auto reductionOp = cast(op); auto kind = reductionOp.kind(); Type eltType = reductionOp.dest().getType(); - Type llvmType = typeConverter.convertType(eltType); + Type llvmType = typeConverter->convertType(eltType); if (eltType.isIntOrIndex()) { // Integer reductions: add/mul/min/max/and/or/xor. if (kind == "add") @@ -580,39 +580,40 @@ else return failure(); return success(); - - } else if (eltType.isa()) { - // Floating-point reductions: add/mul/min/max - if (kind == "add") { - // Optional accumulator (or zero). - Value acc = operands.size() > 1 ? operands[1] - : rewriter.create( - op->getLoc(), llvmType, - rewriter.getZeroAttr(eltType)); - rewriter.replaceOpWithNewOp( - op, llvmType, acc, operands[0], - rewriter.getBoolAttr(reassociateFPReductions)); - } else if (kind == "mul") { - // Optional accumulator (or one). - Value acc = operands.size() > 1 - ? operands[1] - : rewriter.create( - op->getLoc(), llvmType, - rewriter.getFloatAttr(eltType, 1.0)); - rewriter.replaceOpWithNewOp( - op, llvmType, acc, operands[0], - rewriter.getBoolAttr(reassociateFPReductions)); - } else if (kind == "min") - rewriter.replaceOpWithNewOp( - op, llvmType, operands[0]); - else if (kind == "max") - rewriter.replaceOpWithNewOp( - op, llvmType, operands[0]); - else - return failure(); - return success(); } - return failure(); + + if (!eltType.isa()) + return failure(); + + // Floating-point reductions: add/mul/min/max + if (kind == "add") { + // Optional accumulator (or zero). + Value acc = operands.size() > 1 ? operands[1] + : rewriter.create( + op->getLoc(), llvmType, + rewriter.getZeroAttr(eltType)); + rewriter.replaceOpWithNewOp( + op, llvmType, acc, operands[0], + rewriter.getBoolAttr(reassociateFPReductions)); + } else if (kind == "mul") { + // Optional accumulator (or one). + Value acc = operands.size() > 1 + ? operands[1] + : rewriter.create( + op->getLoc(), llvmType, + rewriter.getFloatAttr(eltType, 1.0)); + rewriter.replaceOpWithNewOp( + op, llvmType, acc, operands[0], + rewriter.getBoolAttr(reassociateFPReductions)); + } else if (kind == "min") + rewriter.replaceOpWithNewOp(op, llvmType, + operands[0]); + else if (kind == "max") + rewriter.replaceOpWithNewOp(op, llvmType, + operands[0]); + else + return failure(); + return success(); } private: @@ -663,7 +664,7 @@ auto v1Type = shuffleOp.getV1VectorType(); auto v2Type = shuffleOp.getV2VectorType(); auto vectorType = shuffleOp.getVectorType(); - Type llvmType = typeConverter.convertType(vectorType); + Type llvmType = typeConverter->convertType(vectorType); auto maskArrayAttr = shuffleOp.mask(); // Bail if result type cannot be lowered. @@ -695,9 +696,9 @@ extPos -= v1Dim; value = adaptor.v2(); } - Value extract = extractOne(rewriter, typeConverter, loc, value, llvmType, - rank, extPos); - insert = insertOne(rewriter, typeConverter, loc, insert, extract, + Value extract = extractOne(rewriter, *getTypeConverter(), loc, value, + llvmType, rank, extPos); + insert = insertOne(rewriter, *getTypeConverter(), loc, insert, extract, llvmType, rank, insPos++); } rewriter.replaceOp(op, insert); @@ -718,7 +719,7 @@ auto adaptor = vector::ExtractElementOpAdaptor(operands); auto extractEltOp = cast(op); auto vectorType = extractEltOp.getVectorType(); - auto llvmType = typeConverter.convertType(vectorType.getElementType()); + auto llvmType = typeConverter->convertType(vectorType.getElementType()); // Bail if result type cannot be lowered. if (!llvmType) @@ -745,7 +746,7 @@ auto extractOp = cast(op); auto vectorType = extractOp.getVectorType(); auto resultType = extractOp.getResult().getType(); - auto llvmResultType = typeConverter.convertType(resultType); + auto llvmResultType = typeConverter->convertType(resultType); auto positionArrayAttr = extractOp.position(); // Bail if result type cannot be lowered. @@ -769,7 +770,7 @@ auto nMinusOnePositionAttrs = ArrayAttr::get(positionAttrs.drop_back(), context); extracted = rewriter.create( - loc, typeConverter.convertType(oneDVectorType), extracted, + loc, typeConverter->convertType(oneDVectorType), extracted, nMinusOnePositionAttrs); } @@ -833,7 +834,7 @@ auto adaptor = vector::InsertElementOpAdaptor(operands); auto insertEltOp = cast(op); auto vectorType = insertEltOp.getDestVectorType(); - auto llvmType = typeConverter.convertType(vectorType); + auto llvmType = typeConverter->convertType(vectorType); // Bail if result type cannot be lowered. if (!llvmType) @@ -860,7 +861,7 @@ auto insertOp = cast(op); auto sourceType = insertOp.getSourceType(); auto destVectorType = insertOp.getDestVectorType(); - auto llvmResultType = typeConverter.convertType(destVectorType); + auto llvmResultType = typeConverter->convertType(destVectorType); auto positionArrayAttr = insertOp.position(); // Bail if result type cannot be lowered. @@ -887,7 +888,7 @@ auto nMinusOnePositionAttrs = ArrayAttr::get(positionAttrs.drop_back(), context); extracted = rewriter.create( - loc, typeConverter.convertType(oneDVectorType), extracted, + loc, typeConverter->convertType(oneDVectorType), extracted, nMinusOnePositionAttrs); } @@ -895,7 +896,7 @@ auto i64Type = LLVM::LLVMType::getInt64Ty(rewriter.getContext()); auto constant = rewriter.create(loc, i64Type, position); Value inserted = rewriter.create( - loc, typeConverter.convertType(oneDVectorType), extracted, + loc, typeConverter->convertType(oneDVectorType), extracted, adaptor.source(), constant); // Potential insertion of resulting 1-D vector into array. @@ -1000,7 +1001,7 @@ Value extracted = rewriter.create(loc, op.dest(), getI64SubArray(op.offsets(), /*dropFront=*/0, - /*dropFront=*/rankRest)); + /*dropBack=*/rankRest)); // A different pattern will kick in for InsertStridedSlice with matching // ranks. auto stridedSliceInnerOp = rewriter.create( @@ -1010,7 +1011,7 @@ rewriter.replaceOpWithNewOp( op, stridedSliceInnerOp.getResult(), op.dest(), getI64SubArray(op.offsets(), /*dropFront=*/0, - /*dropFront=*/rankRest)); + /*dropBack=*/rankRest)); return success(); } }; @@ -1144,7 +1145,7 @@ return failure(); MemRefDescriptor sourceMemRef(operands[0]); - auto llvmTargetDescriptorTy = typeConverter.convertType(targetMemRefType) + auto llvmTargetDescriptorTy = typeConverter->convertType(targetMemRefType) .dyn_cast_or_null(); if (!llvmTargetDescriptorTy || !llvmTargetDescriptorTy.isStructTy()) return failure(); @@ -1234,7 +1235,7 @@ if (!strides) return failure(); - auto toLLVMTy = [&](Type t) { return typeConverter.convertType(t); }; + auto toLLVMTy = [&](Type t) { return typeConverter->convertType(t); }; Location loc = op->getLoc(); MemRefType memRefType = xferOp.getMemRefType(); @@ -1279,8 +1280,8 @@ loc, vecTy.getPointerTo(), dataPtr); if (!xferOp.isMaskedDim(0)) - return replaceTransferOpWithLoadOrStore(rewriter, typeConverter, loc, - xferOp, operands, vectorDataPtr); + return replaceTransferOpWithLoadOrStore( + rewriter, *getTypeConverter(), loc, xferOp, operands, vectorDataPtr); // 2. Create a vector with linear indices [ 0 .. vector_length - 1 ]. // 3. Create offsetVector = [ offset + 0 .. offset + vector_length - 1 ]. @@ -1297,8 +1298,8 @@ vecWidth, dim, &off); // 5. Rewrite as a masked read / write. - return replaceTransferOpWithMasked(rewriter, typeConverter, loc, xferOp, - operands, vectorDataPtr, mask); + return replaceTransferOpWithMasked(rewriter, *getTypeConverter(), loc, + xferOp, operands, vectorDataPtr, mask); } private: @@ -1331,7 +1332,7 @@ auto adaptor = vector::PrintOpAdaptor(operands); Type printType = printOp.getPrintType(); - if (typeConverter.convertType(printType) == nullptr) + if (typeConverter->convertType(printType) == nullptr) return failure(); // Make sure element type has runtime support. @@ -1421,10 +1422,10 @@ for (int64_t d = 0; d < dim; ++d) { auto reducedType = rank > 1 ? reducedVectorTypeFront(vectorType) : nullptr; - auto llvmType = typeConverter.convertType( + auto llvmType = typeConverter->convertType( rank > 1 ? reducedType : vectorType.getElementType()); - Value nestedVal = - extractOne(rewriter, typeConverter, loc, value, llvmType, rank, d); + Value nestedVal = extractOne(rewriter, *getTypeConverter(), loc, value, + llvmType, rank, d); emitRanks(rewriter, op, nestedVal, reducedType, printer, rank - 1, conversion); if (d != dim - 1) diff --git a/mlir/lib/Conversion/VectorToROCDL/VectorToROCDL.cpp b/mlir/lib/Conversion/VectorToROCDL/VectorToROCDL.cpp --- a/mlir/lib/Conversion/VectorToROCDL/VectorToROCDL.cpp +++ b/mlir/lib/Conversion/VectorToROCDL/VectorToROCDL.cpp @@ -79,7 +79,7 @@ if (!xferOp.isMaskedDim(0)) return failure(); - auto toLLVMTy = [&](Type t) { return typeConverter.convertType(t); }; + auto toLLVMTy = [&](Type t) { return typeConverter->convertType(t); }; LLVM::LLVMType vecTy = toLLVMTy(xferOp.getVectorType()).template cast(); unsigned vecWidth = vecTy.getVectorNumElements(); @@ -142,9 +142,9 @@ Value int32Zero = rewriter.create( loc, toLLVMTy(i32Ty), rewriter.getIntegerAttr(rewriter.getIntegerType(32), 0)); - return replaceTransferOpWithMubuf(rewriter, operands, typeConverter, loc, - xferOp, vecTy, dwordConfig, int32Zero, - int32Zero, int1False, int1False); + return replaceTransferOpWithMubuf( + rewriter, operands, *getTypeConverter(), loc, xferOp, vecTy, + dwordConfig, int32Zero, int32Zero, int1False, int1False); } }; } // end anonymous namespace