diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
--- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
+++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
@@ -26,6 +26,10 @@
 
 namespace {
 
+bool isStaticStrideOrOffset(int64_t strideOrOffset) {
+  return !ShapedType::isDynamicStrideOrOffset(strideOrOffset);
+}
+
 struct AllocOpLowering : public AllocLikeOpLLVMLowering {
   AllocOpLowering(LLVMTypeConverter &converter)
       : AllocLikeOpLLVMLowering(memref::AllocOp::getOperationName(),
@@ -1091,11 +1095,52 @@
                   Type srcType, memref::ReshapeOp reshapeOp,
                   memref::ReshapeOp::Adaptor adaptor,
                   Value *descriptor) const {
-    // Conversion for statically-known shape args is performed via
-    // `memref_reinterpret_cast`.
     auto shapeMemRefType = reshapeOp.shape().getType().cast<MemRefType>();
-    if (shapeMemRefType.hasStaticShape())
-      return failure();
+    if (shapeMemRefType.hasStaticShape()) {
+      MemRefType targetMemRefType =
+          reshapeOp.getResult().getType().cast<MemRefType>();
+      auto llvmTargetDescriptorTy =
+          typeConverter->convertType(targetMemRefType)
+              .dyn_cast_or_null<LLVM::LLVMStructType>();
+      if (!llvmTargetDescriptorTy)
+        return failure();
+
+      // Create descriptor.
+      Location loc = reshapeOp.getLoc();
+      auto desc =
+          MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy);
+
+      // Set allocated and aligned pointers.
+      Value allocatedPtr, alignedPtr;
+      extractPointersAndOffset(loc, rewriter, *getTypeConverter(),
+                               reshapeOp.source(), adaptor.source(),
+                               &allocatedPtr, &alignedPtr);
+      desc.setAllocatedPtr(rewriter, loc, allocatedPtr);
+      desc.setAlignedPtr(rewriter, loc, alignedPtr);
+
+      // Extract the offset and strides from the type.
+      int64_t offset;
+      SmallVector<int64_t> strides;
+      if (failed(getStridesAndOffset(targetMemRefType, strides, offset)))
+        return rewriter.notifyMatchFailure(
+            reshapeOp, "failed to get stride and offset exprs");
+
+      if (!isStaticStrideOrOffset(offset))
+        return rewriter.notifyMatchFailure(reshapeOp,
+                                           "dynamic offset is unsupported");
+      if (!llvm::all_of(strides, isStaticStrideOrOffset))
+        return rewriter.notifyMatchFailure(reshapeOp,
+                                           "dynamic strides are unsupported");
+
+      desc.setConstantOffset(rewriter, loc, offset);
+      for (unsigned i = 0, e = targetMemRefType.getRank(); i < e; ++i) {
+        desc.setConstantSize(rewriter, loc, i, targetMemRefType.getDimSize(i));
+        desc.setConstantStride(rewriter, loc, i, strides[i]);
+      }
+
+      *descriptor = desc;
+      return success();
+    }
 
     // The shape is a rank-1 tensor with unknown length.
     Location loc = reshapeOp.getLoc();
@@ -1499,10 +1544,7 @@
     for (auto &en : llvm::enumerate(dstShape))
       dstDesc.setSize(rewriter, loc, en.index(), en.value());
 
-    auto isStaticStride = [](int64_t stride) {
-      return !ShapedType::isDynamicStrideOrOffset(stride);
-    };
-    if (llvm::all_of(strides, isStaticStride)) {
+    if (llvm::all_of(strides, isStaticStrideOrOffset)) {
       for (auto &en : llvm::enumerate(strides))
         dstDesc.setConstantStride(rewriter, loc, en.index(), en.value());
     } else if (srcType.getLayout().isIdentity() &&
diff --git a/mlir/test/Conversion/MemRefToLLVM/convert-static-memref-ops.mlir b/mlir/test/Conversion/MemRefToLLVM/convert-static-memref-ops.mlir
--- a/mlir/test/Conversion/MemRefToLLVM/convert-static-memref-ops.mlir
+++ b/mlir/test/Conversion/MemRefToLLVM/convert-static-memref-ops.mlir
@@ -217,3 +217,38 @@
   }
 }
 
+// -----
+
+memref.global "private" constant @__constant_3xi64 : memref<3xi64> = dense<[2, 6, 20]>
+
+// CHECK-LABEL: func @memref.reshape
+// CHECK-SAME: %[[arg0:.*]]: memref<4x5x6xf32>) -> memref<2x6x20xf32>
+func.func @memref.reshape(%arg0: memref<4x5x6xf32>) -> memref<2x6x20xf32> {
+  // CHECK: %[[cast0:.*]] = builtin.unrealized_conversion_cast %arg0 : memref<4x5x6xf32> to !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
+  %0 = memref.get_global @__constant_3xi64 : memref<3xi64>
+
+  // CHECK: %[[undef:.*]] = llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
+  // CHECK: %[[elem0:.*]] = llvm.extractvalue %[[cast0]][0] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
+  // CHECK: %[[elem1:.*]] = llvm.extractvalue %[[cast0]][1] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
+  // CHECK: %[[insert0:.*]] = llvm.insertvalue %[[elem0]], %[[undef]][0] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
+  // CHECK: %[[insert1:.*]] = llvm.insertvalue %[[elem1]], %[[insert0]][1] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
+  // CHECK: %[[zero:.*]] = llvm.mlir.constant(0 : index) : i64
+  // CHECK: %[[insert2:.*]] = llvm.insertvalue %[[zero]], %[[insert1]][2] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
+  // CHECK: %[[two:.*]] = llvm.mlir.constant(2 : index) : i64
+  // CHECK: %[[insert3:.*]] = llvm.insertvalue %[[two]], %[[insert2]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
+  // CHECK: %[[hundred_and_twenty:.*]] = llvm.mlir.constant(120 : index) : i64
+  // CHECK: %[[insert4:.*]] = llvm.insertvalue %[[hundred_and_twenty]], %[[insert3]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
+  // CHECK: %[[six:.*]] = llvm.mlir.constant(6 : index) : i64
+  // CHECK: %[[insert5:.*]] = llvm.insertvalue %[[six]], %[[insert4]][3, 1] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
+  // CHECK: %[[twenty0:.*]] = llvm.mlir.constant(20 : index) : i64
+  // CHECK: %[[insert6:.*]] = llvm.insertvalue %[[twenty0]], %[[insert5]][4, 1] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
+  // CHECK: %[[twenty1:.*]] = llvm.mlir.constant(20 : index) : i64
+  // CHECK: %[[insert7:.*]] = llvm.insertvalue %[[twenty1]], %[[insert6]][3, 2] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
+  // CHECK: %[[one:.*]] = llvm.mlir.constant(1 : index) : i64
+  // CHECK: %[[insert8:.*]] = llvm.insertvalue %[[one]], %[[insert7]][4, 2] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
+  // CHECK: %[[cast1:.*]] = builtin.unrealized_conversion_cast %[[insert8]] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> to memref<2x6x20xf32>
+  %1 = memref.reshape %arg0(%0) : (memref<4x5x6xf32>, memref<3xi64>) -> memref<2x6x20xf32>
+
+  // CHECK: return %[[cast1]] : memref<2x6x20xf32>
+  return %1 : memref<2x6x20xf32>
+}
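
Note on the CHECK lines above: the insertvalue indices follow MLIR's documented
memref descriptor convention, in which a ranked memref lowers to an LLVM struct
of (allocated pointer, aligned pointer, offset, sizes array, strides array). As
a minimal C++ sketch of what that struct corresponds to for memref<2x6x20xf32>
(the struct and field names are illustrative only, not part of the patch):

    // Mirrors !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>.
    struct MemRefDescriptor3D {
      float *allocated;    // index [0]: allocated (malloc'ed) pointer
      float *aligned;      // index [1]: aligned data pointer
      int64_t offset;      // index [2]: offset into the aligned pointer
      int64_t sizes[3];    // indices [3, i]: size of dimension i
      int64_t strides[3];  // indices [4, i]: stride of dimension i
    };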
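
The constants the test expects (offset 0, sizes 2/6/20, strides 120/20/1) are
what getStridesAndOffset yields for the identity-layout result type
memref<2x6x20xf32>: each stride is the product of the trailing dimension sizes,
and the offset is 0. A standalone sketch of that row-major computation (plain
C++, not the MLIR API):

    #include <cstdint>
    #include <cstdio>

    int main() {
      const int64_t shape[3] = {2, 6, 20}; // memref<2x6x20xf32>
      int64_t strides[3];
      int64_t running = 1;
      for (int i = 2; i >= 0; --i) {
        strides[i] = running; // stride(i) = product of shape[i+1..2]
        running *= shape[i];
      }
      // Prints: strides = [120, 20, 1], offset = 0
      std::printf("strides = [%lld, %lld, %lld], offset = 0\n",
                  (long long)strides[0], (long long)strides[1],
                  (long long)strides[2]);
      return 0;
    }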