diff --git a/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp
--- a/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp
+++ b/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp
@@ -430,7 +430,17 @@
                                           LLVMTypeConverter &typeConverter,
                                           MemRefType type, Value memory) {
   assert(type.hasStaticShape() && "unexpected dynamic shape");
-  assert(type.getAffineMaps().empty() && "unexpected layout map");
+
+  // Extract all strides and offsets and verify they are static.
+  int64_t offset;
+  SmallVector<int64_t, 4> strides;
+  auto result = getStridesAndOffset(type, strides, offset);
+  (void)result;
+  assert(succeeded(result) && "unexpected failure in stride computation");
+  assert(offset != MemRefType::getDynamicStrideOrOffset() &&
+         "expected static offset");
+  assert(!llvm::is_contained(strides, MemRefType::getDynamicStrideOrOffset()) &&
+         "expected static strides");
 
   auto convertedType = typeConverter.convertType(type);
   assert(convertedType && "unexpected failure in memref type conversion");
@@ -438,16 +448,12 @@
   auto descr = MemRefDescriptor::undef(builder, loc, convertedType);
   descr.setAllocatedPtr(builder, loc, memory);
   descr.setAlignedPtr(builder, loc, memory);
-  descr.setConstantOffset(builder, loc, 0);
-
-  // Fill in sizes and strides, in reverse order to simplify stride
-  // calculation.
-  uint64_t runningStride = 1;
-  for (unsigned i = type.getRank(); i > 0; --i) {
-    unsigned dim = i - 1;
-    descr.setConstantSize(builder, loc, dim, type.getDimSize(dim));
-    descr.setConstantStride(builder, loc, dim, runningStride);
-    runningStride *= type.getDimSize(dim);
+  descr.setConstantOffset(builder, loc, offset);
+
+  // Fill in sizes and strides.
+  for (unsigned i = 0, e = type.getRank(); i != e; ++i) {
+    descr.setConstantSize(builder, loc, i, type.getDimSize(i));
+    descr.setConstantStride(builder, loc, i, strides[i]);
   }
   return descr;
 }
diff --git a/mlir/test/Conversion/GPUToNVVM/memory-attrbution.mlir b/mlir/test/Conversion/GPUToNVVM/memory-attrbution.mlir
--- a/mlir/test/Conversion/GPUToNVVM/memory-attrbution.mlir
+++ b/mlir/test/Conversion/GPUToNVVM/memory-attrbution.mlir
@@ -92,18 +92,18 @@
     // CHECK: %[[descr3:.*]] = llvm.insertvalue %[[raw]], %[[descr2]][1]
     // CHECK: %[[c0:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64
     // CHECK: %[[descr4:.*]] = llvm.insertvalue %[[c0]], %[[descr3]][2]
-    // CHECK: %[[c6:.*]] = llvm.mlir.constant(6 : index) : !llvm.i64
-    // CHECK: %[[descr5:.*]] = llvm.insertvalue %[[c6]], %[[descr4]][3, 2]
-    // CHECK: %[[c1:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64
-    // CHECK: %[[descr6:.*]] = llvm.insertvalue %[[c1]], %[[descr5]][4, 2]
+    // CHECK: %[[c4:.*]] = llvm.mlir.constant(4 : index) : !llvm.i64
+    // CHECK: %[[descr5:.*]] = llvm.insertvalue %[[c4]], %[[descr4]][3, 0]
+    // CHECK: %[[c12:.*]] = llvm.mlir.constant(12 : index) : !llvm.i64
+    // CHECK: %[[descr6:.*]] = llvm.insertvalue %[[c12]], %[[descr5]][4, 0]
     // CHECK: %[[c2:.*]] = llvm.mlir.constant(2 : index) : !llvm.i64
     // CHECK: %[[descr7:.*]] = llvm.insertvalue %[[c2]], %[[descr6]][3, 1]
     // CHECK: %[[c6:.*]] = llvm.mlir.constant(6 : index) : !llvm.i64
     // CHECK: %[[descr8:.*]] = llvm.insertvalue %[[c6]], %[[descr7]][4, 1]
-    // CHECK: %[[c4:.*]] = llvm.mlir.constant(4 : index) : !llvm.i64
-    // CHECK: %[[descr9:.*]] = llvm.insertvalue %[[c4]], %[[descr8]][3, 0]
-    // CHECK: %[[c12:.*]] = llvm.mlir.constant(12 : index) : !llvm.i64
-    // CHECK: %[[descr10:.*]] = llvm.insertvalue %[[c12]], %[[descr9]][4, 0]
+    // CHECK: %[[c6:.*]] = llvm.mlir.constant(6 : index) : !llvm.i64
+    // CHECK: %[[descr9:.*]] = llvm.insertvalue %[[c6]], %[[descr8]][3, 2]
+    // CHECK: %[[c1:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64
+    // CHECK: %[[descr10:.*]] = llvm.insertvalue %[[c1]], %[[descr9]][4, 2]
 
     %c0 = constant 0 : index
     store %arg0, %arg1[%c0,%c0,%c0] : memref<4x2x6xf32, 3>
diff --git a/mlir/test/Conversion/StandardToLLVM/convert-static-memref-ops.mlir b/mlir/test/Conversion/StandardToLLVM/convert-static-memref-ops.mlir
--- a/mlir/test/Conversion/StandardToLLVM/convert-static-memref-ops.mlir
+++ b/mlir/test/Conversion/StandardToLLVM/convert-static-memref-ops.mlir
@@ -24,20 +24,48 @@
 // BAREPTR-NEXT: %[[aligned:.*]] = llvm.insertvalue %[[arg]], %[[base]][1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
 // BAREPTR-NEXT: %[[val0:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64
 // BAREPTR-NEXT: %[[ins0:.*]] = llvm.insertvalue %[[val0]], %[[aligned]][2] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
-// BAREPTR-NEXT: %[[val1:.*]] = llvm.mlir.constant(18 : index) : !llvm.i64
-// BAREPTR-NEXT: %[[ins1:.*]] = llvm.insertvalue %[[val1]], %[[ins0]][3, 1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
-// BAREPTR-NEXT: %[[val2:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64
-// BAREPTR-NEXT: %[[ins2:.*]] = llvm.insertvalue %[[val2]], %[[ins1]][4, 1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
-// BAREPTR-NEXT: %[[val3:.*]] = llvm.mlir.constant(32 : index) : !llvm.i64
-// BAREPTR-NEXT: %[[ins3:.*]] = llvm.insertvalue %[[val3]], %[[ins2]][3, 0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
-// BAREPTR-NEXT: %[[val4:.*]] = llvm.mlir.constant(18 : index) : !llvm.i64
-// BAREPTR-NEXT: %[[ins4:.*]] = llvm.insertvalue %[[val4]], %[[ins3]][4, 0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+// BAREPTR-NEXT: %[[val1:.*]] = llvm.mlir.constant(32 : index) : !llvm.i64
+// BAREPTR-NEXT: %[[ins1:.*]] = llvm.insertvalue %[[val1]], %[[ins0]][3, 0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+// BAREPTR-NEXT: %[[val2:.*]] = llvm.mlir.constant(18 : index) : !llvm.i64
+// BAREPTR-NEXT: %[[ins2:.*]] = llvm.insertvalue %[[val2]], %[[ins1]][4, 0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+// BAREPTR-NEXT: %[[val3:.*]] = llvm.mlir.constant(18 : index) : !llvm.i64
+// BAREPTR-NEXT: %[[ins3:.*]] = llvm.insertvalue %[[val3]], %[[ins2]][3, 1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+// BAREPTR-NEXT: %[[val4:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64
+// BAREPTR-NEXT: %[[ins4:.*]] = llvm.insertvalue %[[val4]], %[[ins3]][4, 1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
 // BAREPTR-NEXT: llvm.return %[[ins4]] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
   return %static : memref<32x18xf32>
 }
 
 // -----
 
+// CHECK-LABEL: func @check_static_return_with_offset
+// CHECK-COUNT-2: !llvm<"float*">
+// CHECK-COUNT-5: !llvm.i64
+// CHECK-SAME: -> !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+// BAREPTR-LABEL: func @check_static_return_with_offset
+// BAREPTR-SAME: (%[[arg:.*]]: !llvm<"float*">) -> !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> {
+func @check_static_return_with_offset(%static : memref<32x18xf32, offset:7, strides:[22,1]>) -> memref<32x18xf32, offset:7, strides:[22,1]> {
+// CHECK: llvm.return %{{.*}} : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+
+// BAREPTR: %[[udf:.*]] = llvm.mlir.undef : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+// BAREPTR-NEXT: %[[base:.*]] = llvm.insertvalue %[[arg]], %[[udf]][0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+// BAREPTR-NEXT: %[[aligned:.*]] = llvm.insertvalue %[[arg]], %[[base]][1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+// BAREPTR-NEXT: %[[val0:.*]] = llvm.mlir.constant(7 : index) : !llvm.i64
+// BAREPTR-NEXT: %[[ins0:.*]] = llvm.insertvalue %[[val0]], %[[aligned]][2] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+// BAREPTR-NEXT: %[[val1:.*]] = llvm.mlir.constant(32 : index) : !llvm.i64
+// BAREPTR-NEXT: %[[ins1:.*]] = llvm.insertvalue %[[val1]], %[[ins0]][3, 0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+// BAREPTR-NEXT: %[[val2:.*]] = llvm.mlir.constant(22 : index) : !llvm.i64
+// BAREPTR-NEXT: %[[ins2:.*]] = llvm.insertvalue %[[val2]], %[[ins1]][4, 0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+// BAREPTR-NEXT: %[[val3:.*]] = llvm.mlir.constant(18 : index) : !llvm.i64
+// BAREPTR-NEXT: %[[ins3:.*]] = llvm.insertvalue %[[val3]], %[[ins2]][3, 1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+// BAREPTR-NEXT: %[[val4:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64
+// BAREPTR-NEXT: %[[ins4:.*]] = llvm.insertvalue %[[val4]], %[[ins3]][4, 1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+// BAREPTR-NEXT: llvm.return %[[ins4]] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+  return %static : memref<32x18xf32, offset:7, strides:[22,1]>
+}
+
+// -----
+
 // CHECK-LABEL: func @zero_d_alloc() -> !llvm<"{ float*, float*, i64 }"> {
 // ALLOCA-LABEL: func @zero_d_alloc() -> !llvm<"{ float*, float*, i64 }"> {
 // BAREPTR-LABEL: func @zero_d_alloc() -> !llvm<"{ float*, float*, i64 }"> {
@@ -302,7 +330,7 @@
 // BAREPTR-LABEL: func @static_memref_dim(%{{.*}}: !llvm<"float*">) {
 func @static_memref_dim(%static : memref<42x32x15x13x27xf32>) {
 // CHECK: llvm.mlir.constant(42 : index) : !llvm.i64
-// BAREPTR: llvm.insertvalue %{{.*}}, %{{.*}}[4, 0] : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }">
+// BAREPTR: llvm.insertvalue %{{.*}}, %{{.*}}[4, 4] : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }">
 // BAREPTR-NEXT: llvm.mlir.constant(42 : index) : !llvm.i64
   %0 = dim %static, 0 : memref<42x32x15x13x27xf32>
 // CHECK-NEXT: llvm.mlir.constant(32 : index) : !llvm.i64
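
Note on the new CHECK values (not part of the patch): for an identity-layout memref, getStridesAndOffset yields offset 0 and row-major strides, where each stride is the product of all trailing dimension sizes; that is why memref<4x2x6xf32> now expects strides [12, 6, 1], emitted in ascending dimension order instead of the old reverse order. A memref with an explicit strided layout, such as memref<32x18xf32, offset:7, strides:[22,1]>, simply carries its static offset and strides into the descriptor. The standalone C++ sketch below (plain STL, no MLIR dependencies; rowMajorStrides is a name invented here for illustration) reproduces the identity-layout stride computation and checks it against the values the updated tests expect.

#include <cassert>
#include <cstdint>
#include <vector>

// Row-major strides for a static shape: stride[i] is the product of all
// trailing dimension sizes. The identity layout also implies offset 0.
static std::vector<int64_t> rowMajorStrides(const std::vector<int64_t> &shape) {
  std::vector<int64_t> strides(shape.size());
  int64_t running = 1;
  for (int i = static_cast<int>(shape.size()) - 1; i >= 0; --i) {
    strides[i] = running;
    running *= shape[i];
  }
  return strides;
}

int main() {
  // memref<4x2x6xf32> from the GPUToNVVM test: strides [12, 6, 1].
  assert((rowMajorStrides({4, 2, 6}) == std::vector<int64_t>{12, 6, 1}));
  // memref<32x18xf32> from the StandardToLLVM test: strides [18, 1].
  assert((rowMajorStrides({32, 18}) == std::vector<int64_t>{18, 1}));
  return 0;
}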