diff --git a/mlir/include/mlir/Conversion/LLVMCommon/MemRefBuilder.h b/mlir/include/mlir/Conversion/LLVMCommon/MemRefBuilder.h
--- a/mlir/include/mlir/Conversion/LLVMCommon/MemRefBuilder.h
+++ b/mlir/include/mlir/Conversion/LLVMCommon/MemRefBuilder.h
@@ -157,11 +157,11 @@
                                       Type descriptorType);
 
   /// Builds IR extracting the rank from the descriptor
-  Value rank(OpBuilder &builder, Location loc);
+  Value rank(OpBuilder &builder, Location loc) const;
   /// Builds IR setting the rank in the descriptor
   void setRank(OpBuilder &builder, Location loc, Value value);
 
   /// Builds IR extracting ranked memref descriptor ptr
-  Value memRefDescPtr(OpBuilder &builder, Location loc);
+  Value memRefDescPtr(OpBuilder &builder, Location loc) const;
   /// Builds IR setting ranked memref descriptor ptr
   void setMemRefDescPtr(OpBuilder &builder, Location loc, Value value);
@@ -183,10 +183,13 @@
   static unsigned getNumUnpackedValues() { return 2; }
 
   /// Builds IR computing the sizes in bytes (suitable for opaque allocation)
-  /// and appends the corresponding values into `sizes`.
+  /// and appends the corresponding values into `sizes`. `addressSpaces`,
+  /// which must have the same length as `values`, is needed to handle layouts
+  /// where sizeof(ptr addrspace(N)) != sizeof(ptr addrspace(0)).
   static void computeSizes(OpBuilder &builder, Location loc,
                            LLVMTypeConverter &typeConverter,
                            ArrayRef<UnrankedMemRefDescriptor> values,
+                           ArrayRef<unsigned> addressSpaces,
                            SmallVectorImpl<Value> &sizes);
 
   /// TODO: The following accessors don't take alignment rules between elements
diff --git a/mlir/include/mlir/Conversion/LLVMCommon/StructBuilder.h b/mlir/include/mlir/Conversion/LLVMCommon/StructBuilder.h
--- a/mlir/include/mlir/Conversion/LLVMCommon/StructBuilder.h
+++ b/mlir/include/mlir/Conversion/LLVMCommon/StructBuilder.h
@@ -41,7 +41,7 @@
 protected:
   /// Builds IR to extract a value from the struct at position pos
-  Value extractPtr(OpBuilder &builder, Location loc, unsigned pos);
+  Value extractPtr(OpBuilder &builder, Location loc, unsigned pos) const;
   /// Builds IR to set a value in the struct at position pos
   void setPtr(OpBuilder &builder, Location loc, unsigned pos, Value ptr);
 };
diff --git a/mlir/lib/Conversion/LLVMCommon/MemRefBuilder.cpp b/mlir/lib/Conversion/LLVMCommon/MemRefBuilder.cpp
--- a/mlir/lib/Conversion/LLVMCommon/MemRefBuilder.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/MemRefBuilder.cpp
@@ -296,7 +296,7 @@
   Value descriptor = builder.create<LLVM::UndefOp>(loc, descriptorType);
   return UnrankedMemRefDescriptor(descriptor);
 }
-Value UnrankedMemRefDescriptor::rank(OpBuilder &builder, Location loc) {
+Value UnrankedMemRefDescriptor::rank(OpBuilder &builder, Location loc) const {
   return extractPtr(builder, loc, kRankInUnrankedMemRefDescriptor);
 }
 void UnrankedMemRefDescriptor::setRank(OpBuilder &builder, Location loc,
@@ -304,7 +304,7 @@
   setPtr(builder, loc, kRankInUnrankedMemRefDescriptor, v);
 }
 Value UnrankedMemRefDescriptor::memRefDescPtr(OpBuilder &builder,
-                                              Location loc) {
+                                              Location loc) const {
   return extractPtr(builder, loc, kPtrInUnrankedMemRefDescriptor);
 }
 void UnrankedMemRefDescriptor::setMemRefDescPtr(OpBuilder &builder,
@@ -341,24 +341,24 @@
 
 void UnrankedMemRefDescriptor::computeSizes(
     OpBuilder &builder, Location loc, LLVMTypeConverter &typeConverter,
-    ArrayRef<UnrankedMemRefDescriptor> values, SmallVectorImpl<Value> &sizes) {
+    ArrayRef<UnrankedMemRefDescriptor> values, ArrayRef<unsigned> addressSpaces,
+    SmallVectorImpl<Value> &sizes) {
   if (values.empty())
     return;
-
+  assert(values.size() == addressSpaces.size() &&
+         "must provide address space for each descriptor");
   // Cache the index type.
   Type indexType = typeConverter.getIndexType();
 
   // Initialize shared constants.
   Value one = createIndexAttrConstant(builder, loc, indexType, 1);
   Value two = createIndexAttrConstant(builder, loc, indexType, 2);
-  Value pointerSize = createIndexAttrConstant(
-      builder, loc, indexType, ceilDiv(typeConverter.getPointerBitwidth(), 8));
   Value indexSize = createIndexAttrConstant(
       builder, loc, indexType,
       ceilDiv(typeConverter.getIndexTypeBitwidth(), 8));
 
   sizes.reserve(sizes.size() + values.size());
-  for (UnrankedMemRefDescriptor desc : values) {
+  for (auto [desc, addressSpace] : llvm::zip(values, addressSpaces)) {
     // Emit IR computing the memory necessary to store the descriptor. This
     // assumes the descriptor to be
     //   { type*, type*, index, index[rank], index[rank] }
@@ -366,6 +366,9 @@
     //   2 * sizeof(pointer) + (1 + 2 * rank) * sizeof(index).
     // TODO: consider including the actual size (including eventual padding due
     // to data layout) into the unranked descriptor.
+    Value pointerSize = createIndexAttrConstant(
+        builder, loc, indexType,
+        ceilDiv(typeConverter.getPointerBitwidth(addressSpace), 8));
     Value doublePointerSize =
         builder.create<LLVM::MulOp>(loc, indexType, two, pointerSize);
diff --git a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
--- a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
@@ -232,18 +232,27 @@
          "expected as may original types as operands");
 
   // Find operands of unranked memref type and store them.
-  SmallVector<UnrankedMemRefDescriptor, 4> unrankedMemrefs;
-  for (unsigned i = 0, e = operands.size(); i < e; ++i)
-    if (origTypes[i].isa<UnrankedMemRefType>())
+  SmallVector<UnrankedMemRefDescriptor> unrankedMemrefs;
+  SmallVector<unsigned> unrankedAddressSpaces;
+  for (unsigned i = 0, e = operands.size(); i < e; ++i) {
+    if (auto memRefType = origTypes[i].dyn_cast<UnrankedMemRefType>()) {
       unrankedMemrefs.emplace_back(operands[i]);
+      FailureOr<unsigned> addressSpace =
+          getTypeConverter()->getMemRefAddressSpace(memRefType);
+      if (failed(addressSpace))
+        return failure();
+      unrankedAddressSpaces.emplace_back(*addressSpace);
+    }
+  }
 
   if (unrankedMemrefs.empty())
     return success();
 
   // Compute allocation sizes.
-  SmallVector<Value, 4> sizes;
+  SmallVector<Value> sizes;
   UnrankedMemRefDescriptor::computeSizes(builder, loc, *getTypeConverter(),
-                                         unrankedMemrefs, sizes);
+                                         unrankedMemrefs, unrankedAddressSpaces,
+                                         sizes);
 
   // Get frequently used types.
   MLIRContext *context = builder.getContext();
diff --git a/mlir/lib/Conversion/LLVMCommon/StructBuilder.cpp b/mlir/lib/Conversion/LLVMCommon/StructBuilder.cpp
--- a/mlir/lib/Conversion/LLVMCommon/StructBuilder.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/StructBuilder.cpp
@@ -23,7 +23,7 @@
 }
 
 Value StructBuilder::extractPtr(OpBuilder &builder, Location loc,
-                                unsigned pos) {
+                                unsigned pos) const {
   return builder.create<LLVM::ExtractValueOp>(loc, value, pos);
 }
 
diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
--- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
+++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
@@ -1329,7 +1329,7 @@
   targetDesc.setRank(rewriter, loc, resultRank);
   SmallVector<Value> sizes;
   UnrankedMemRefDescriptor::computeSizes(rewriter, loc, *getTypeConverter(),
-                                         targetDesc, sizes);
+                                         targetDesc, addressSpace, sizes);
   Value underlyingDescPtr = rewriter.create<LLVM::AllocaOp>(
       loc, getVoidPtrType(), IntegerType::get(getContext(), 8),
       sizes.front());
diff --git a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
--- a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
+++ b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
@@ -742,8 +742,6 @@
   if (failed(warpMatrixInfo))
     return failure();
 
-  Attribute memorySpace =
-      op.getSource().getType().cast<MemRefType>().getMemorySpace();
   bool isLdMatrixCompatible =
       isSharedMemory(op.getSource().getType().cast<MemRefType>()) &&
       nvgpu::inferTileWidthInBits(*warpMatrixInfo) == 128;
diff --git a/mlir/test/Conversion/FuncToLLVM/calling-convention.mlir b/mlir/test/Conversion/FuncToLLVM/calling-convention.mlir
--- a/mlir/test/Conversion/FuncToLLVM/calling-convention.mlir
+++ b/mlir/test/Conversion/FuncToLLVM/calling-convention.mlir
@@ -122,9 +122,9 @@
   // CHECK: %[[ONE:.*]] = llvm.mlir.constant(1 : index)
   // CHECK: %[[TWO:.*]] = llvm.mlir.constant(2 : index)
   // These sizes may depend on the data layout, not matching specific values.
-  // CHECK: %[[PTR_SIZE:.*]] = llvm.mlir.constant
   // CHECK: %[[IDX_SIZE:.*]] = llvm.mlir.constant
+  // CHECK: %[[PTR_SIZE:.*]] = llvm.mlir.constant
   // CHECK: %[[DOUBLE_PTR_SIZE:.*]] = llvm.mul %[[TWO]], %[[PTR_SIZE]]
   // CHECK: %[[RANK:.*]] = llvm.extractvalue %[[CALL_RES]][0] : !llvm.struct<(i64, ptr)>
   // CHECK: %[[DOUBLE_RANK:.*]] = llvm.mul %[[TWO]], %[[RANK]]
@@ -153,13 +153,12 @@
   // CHECK: %[[DESC_2:.*]] = llvm.insertvalue %[[ALLOCA]], %[[DESC_1]][1]
   %0 = memref.cast %arg0: memref<4x3xf32> to memref<*xf32>
 
-  // CHECK: %[[ONE:.*]] = llvm.mlir.constant(1 : index)
   // CHECK: %[[TWO:.*]] = llvm.mlir.constant(2 : index)
   // These sizes may depend on the data layout, not matching specific values.
-  // CHECK: %[[PTR_SIZE:.*]] = llvm.mlir.constant
   // CHECK: %[[IDX_SIZE:.*]] = llvm.mlir.constant
+  // CHECK: %[[PTR_SIZE:.*]] = llvm.mlir.constant
   // CHECK: %[[DOUBLE_PTR_SIZE:.*]] = llvm.mul %[[TWO]], %[[PTR_SIZE]]
   // CHECK: %[[DOUBLE_RANK:.*]] = llvm.mul %[[TWO]], %[[RANK]]
   // CHECK: %[[DOUBLE_RANK_INC:.*]] = llvm.add %[[DOUBLE_RANK]], %[[ONE]]
diff --git a/mlir/test/Conversion/MemRefToLLVM/convert-dynamic-memref-ops.mlir b/mlir/test/Conversion/MemRefToLLVM/convert-dynamic-memref-ops.mlir
--- a/mlir/test/Conversion/MemRefToLLVM/convert-dynamic-memref-ops.mlir
+++ b/mlir/test/Conversion/MemRefToLLVM/convert-dynamic-memref-ops.mlir
@@ -408,8 +408,8 @@
 // Compute size in bytes to allocate result ranked descriptor
 // CHECK: [[C1:%.*]] = llvm.mlir.constant(1 : index) : i64
 // CHECK: [[C2:%.*]] = llvm.mlir.constant(2 : index) : i64
-// CHECK: [[PTR_SIZE:%.*]] = llvm.mlir.constant(8 : index) : i64
 // CHECK: [[INDEX_SIZE:%.*]] = llvm.mlir.constant(8 : index) : i64
+// CHECK: [[PTR_SIZE:%.*]] = llvm.mlir.constant(8 : index) : i64
 // CHECK: [[DOUBLE_PTR_SIZE:%.*]] = llvm.mul [[C2]], [[PTR_SIZE]] : i64
 // CHECK: [[DESC_ALLOC_SIZE:%.*]] = llvm.add [[DOUBLE_PTR_SIZE]], %{{.*}}
 // CHECK: [[UNDERLYING_DESC:%.*]] = llvm.alloca [[DESC_ALLOC_SIZE]] x i8
diff --git a/mlir/test/Conversion/MemRefToLLVM/typed-pointers.mlir b/mlir/test/Conversion/MemRefToLLVM/typed-pointers.mlir
--- a/mlir/test/Conversion/MemRefToLLVM/typed-pointers.mlir
+++ b/mlir/test/Conversion/MemRefToLLVM/typed-pointers.mlir
@@ -323,8 +323,8 @@
 // Compute size in bytes to allocate result ranked descriptor
 // CHECK: [[C1:%.*]] = llvm.mlir.constant(1 : index) : i64
 // CHECK: [[C2:%.*]] = llvm.mlir.constant(2 : index) : i64
-// CHECK: [[PTR_SIZE:%.*]] = llvm.mlir.constant(8 : index) : i64
 // CHECK: [[INDEX_SIZE:%.*]] = llvm.mlir.constant(8 : index) : i64
+// CHECK: [[PTR_SIZE:%.*]] = llvm.mlir.constant(8 : index) : i64
 // CHECK: [[DOUBLE_PTR_SIZE:%.*]] = llvm.mul [[C2]], [[PTR_SIZE]] : i64
 // CHECK: [[DESC_ALLOC_SIZE:%.*]] = llvm.add [[DOUBLE_PTR_SIZE]], %{{.*}}
 // CHECK: [[UNDERLYING_DESC:%.*]] = llvm.alloca [[DESC_ALLOC_SIZE]] x i8
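
Reviewer note: below is a minimal sketch of how client code is expected to drive the
new computeSizes() overload, mirroring the copyUnrankedDescriptors() change in
Pattern.cpp above. The helper name computeUnrankedDescriptorSizes and its parameter
list are hypothetical; only getMemRefAddressSpace(), getPointerBitwidth(unsigned),
and the extended computeSizes() signature come from this patch.

#include "mlir/Conversion/LLVMCommon/MemRefBuilder.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/IR/BuiltinTypes.h"
#include "llvm/ADT/STLExtras.h"

using namespace mlir;

// Hypothetical helper: gathers the unranked-memref operands together with
// their address spaces and asks computeSizes() for the allocation sizes.
static LogicalResult
computeUnrankedDescriptorSizes(OpBuilder &builder, Location loc,
                               LLVMTypeConverter &typeConverter,
                               ArrayRef<Type> origTypes, ArrayRef<Value> operands,
                               SmallVectorImpl<Value> &sizes) {
  SmallVector<UnrankedMemRefDescriptor> descriptors;
  SmallVector<unsigned> addressSpaces;
  for (auto [type, operand] : llvm::zip(origTypes, operands)) {
    if (auto memRefType = type.dyn_cast<UnrankedMemRefType>()) {
      descriptors.emplace_back(operand);
      // Resolve the address space per operand: with this patch the pointer
      // size is looked up per address space instead of assuming that
      // sizeof(ptr addrspace(N)) == sizeof(ptr addrspace(0)).
      FailureOr<unsigned> addressSpace =
          typeConverter.getMemRefAddressSpace(memRefType);
      if (failed(addressSpace))
        return failure();
      addressSpaces.push_back(*addressSpace);
    }
  }
  // Emits, for each descriptor, IR computing
  //   2 * sizeof(ptr addrspace(AS)) + (1 + 2 * rank) * sizeof(index),
  // suitable for an opaque llvm.alloca or malloc of the descriptor copy.
  UnrankedMemRefDescriptor::computeSizes(builder, loc, typeConverter,
                                         descriptors, addressSpaces, sizes);
  return success();
}

Collecting the address spaces alongside the descriptors keeps the two arrays
index-aligned, which is exactly what the new assert in computeSizes() enforces.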