diff --git a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
--- a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
+++ b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
@@ -522,6 +522,11 @@
                              ArrayRef<int64_t> strides, int64_t offset,
                              ConversionPatternRewriter &rewriter) const;
 
+  /// Returns true if the given memref type is supported by the lowering:
+  /// its element type must convert to an LLVM type and every affine map
+  /// attached to it (if any) must be the identity.
+  bool isSupportedMemRefType(MemRefType type) const;
+
   Value getDataPtr(Location loc, MemRefType type, Value memRefDesc,
                    ValueRange indices, ConversionPatternRewriter &rewriter) const;
 
diff --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
--- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
+++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
@@ -1094,11 +1094,24 @@
                               offset, rewriter);
 }
 
+// Check if the MemRefType `type` is supported by the lowering. We currently
+// only support memrefs whose element type converts to an LLVM type and whose
+// affine maps, if any, are all identities.
+bool ConvertToLLVMPattern::isSupportedMemRefType(MemRefType type) const {
+  if (!typeConverter.convertType(type.getElementType()))
+    return false;
+  return type.getAffineMaps().empty() ||
+         llvm::all_of(type.getAffineMaps(),
+                      [](AffineMap map) { return map.isIdentity(); });
+}
+
 Type ConvertToLLVMPattern::getElementPtrType(MemRefType type) const {
   auto elementType = type.getElementType();
-  auto structElementType = typeConverter.convertType(elementType);
-  return structElementType.cast<LLVM::LLVMType>().getPointerTo(
-      type.getMemorySpace());
+  auto structElementType = unwrap(typeConverter.convertType(elementType));
+  // A null result means the element type failed to convert; callers are
+  // expected to have rejected such memrefs via isSupportedMemRefType().
+  assert(structElementType && "expected a convertible memref element type");
+  return structElementType.getPointerTo(type.getMemorySpace());
 }
 
 void ConvertToLLVMPattern::getMemRefDescriptorSizes(
@@ -1912,14 +1925,6 @@
   }
 };
 
-// Check if the MemRefType `type` is supported by the lowering. We currently
-// only support memrefs with identity maps.
-static bool isSupportedMemRefType(MemRefType type) {
-  return type.getAffineMaps().empty() ||
-         llvm::all_of(type.getAffineMaps(),
-                      [](AffineMap map) { return map.isIdentity(); });
-}
-
 /// Lowering for AllocOp and AllocaOp.
 struct AllocLikeOpLowering : public ConvertToLLVMPattern {
   using ConvertToLLVMPattern::createIndexConstant;
@@ -2434,6 +2439,8 @@
   // LLVM dialect, so the LLVM type needs to me a multi-dimension array.
   LLVM::LLVMType elementType =
       unwrap(typeConverter.convertType(type.getElementType()));
+  if (!elementType)
+    return nullptr;
   LLVM::LLVMType arrayTy = elementType;
   // Shape has the outermost dim at index 0, so need to walk it backwards
   for (int64_t dim : llvm::reverse(type.getShape()))
@@ -3070,6 +3077,7 @@
 template <typename Derived>
 struct LoadStoreOpLowering : public ConvertOpToLLVMPattern<Derived> {
   using ConvertOpToLLVMPattern<Derived>::ConvertOpToLLVMPattern;
+  using ConvertOpToLLVMPattern<Derived>::isSupportedMemRefType;
   using Base = LoadStoreOpLowering<Derived>;
 
   LogicalResult match(Operation *op) const override {