diff --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
--- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
+++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
@@ -2581,6 +2581,10 @@
     auto viewOp = cast<ViewOp>(op);
     ViewOpOperandAdaptor adaptor(operands);

+    auto sourceMemRefType = viewOp.source().getType().cast<MemRefType>();
+    auto sourceElementTy =
+        typeConverter.convertType(sourceMemRefType.getElementType())
+            .dyn_cast<LLVM::LLVMType>();
     auto viewMemRefType = viewOp.getType();
     auto targetElementTy =
         typeConverter.convertType(viewMemRefType.getElementType())
@@ -2591,6 +2595,14 @@
       return op->emitWarning("Target descriptor type not converted to LLVM"),
              failure();

+    int64_t baseOffset;
+    {
+      SmallVector<int64_t, 4> strides;
+      if (failed(getStridesAndOffset(sourceMemRefType, strides, baseOffset)))
+        return op->emitWarning("cannot cast base type to non-strided shape"),
+               failure();
+    }
+
     int64_t offset;
     SmallVector<int64_t, 4> strides;
     auto successStrides = getStridesAndOffset(viewMemRefType, strides, offset);
@@ -2608,9 +2620,16 @@
     targetMemRef.setAllocatedPtr(rewriter, loc, bitcastPtr);

     // Field 2: Copy the actual aligned pointer to payload.
-    extracted = sourceMemRef.alignedPtr(rewriter, loc);
+    Value alignedPtr = sourceMemRef.alignedPtr(rewriter, loc);
+
+    if (baseOffset > 0)
+      alignedPtr = rewriter.create<LLVM::GEPOp>(
+          loc, sourceElementTy.getPointerTo(), alignedPtr,
+          ValueRange{createIndexConstant(rewriter, loc, baseOffset)});
+
     bitcastPtr = rewriter.create<LLVM::BitcastOp>(
-        loc, targetElementTy.getPointerTo(), extracted);
+        loc, targetElementTy.getPointerTo(), alignedPtr);
+
     targetMemRef.setAlignedPtr(rewriter, loc, bitcastPtr);

     // Field 3: Copy the offset in aligned pointer.
@@ -2620,11 +2639,11 @@
     auto sizeAndOffsetOperands = adaptor.operands();
     assert(llvm::size(sizeAndOffsetOperands) ==
            numDynamicSizes + (hasDynamicOffset ? 1 : 0));
-    Value baseOffset = !hasDynamicOffset
-                           ? createIndexConstant(rewriter, loc, offset)
-                           // TODO(ntv): better adaptor.
-                           : sizeAndOffsetOperands.front();
-    targetMemRef.setOffset(rewriter, loc, baseOffset);
+    Value targetOffset = !hasDynamicOffset
+                             ? createIndexConstant(rewriter, loc, offset)
+                             // TODO(ntv): better adaptor.
+                             : sizeAndOffsetOperands.front();
+    targetMemRef.setOffset(rewriter, loc, targetOffset);

     // Early exit for 0-D corner case.
     if (viewMemRefType.getRank() == 0)
diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -2701,12 +2701,24 @@
   auto baseType = op.getOperand(0).getType().cast<MemRefType>();
   auto viewType = op.getResult().getType().cast<MemRefType>();

-  // The base memref should have identity layout map (or none).
-  if (baseType.getAffineMaps().size() > 1 ||
-      (baseType.getAffineMaps().size() == 1 &&
-       !baseType.getAffineMaps()[0].isIdentity()))
+  // The base memref should have at most one affine map, with unit stride.
+  if (baseType.getAffineMaps().size() > 1)
     return op.emitError("unsupported map for base memref type ") << baseType;

+  if (baseType.getAffineMaps().size() == 1) {
+    int64_t offset;
+    SmallVector<int64_t, 4> strides;
+    if (failed(getStridesAndOffset(baseType, strides, offset)))
+      return op.emitError("base type ") << baseType << " is not strided";
+
+    if (offset == MemRefType::getDynamicStrideOrOffset())
+      return op.emitError("base type ") << baseType << " has dynamic offset";
+
+    if (strides.size() != 1 || strides.back() != 1)
+      return op.emitError("base type ")
+             << baseType << " has non-trivial strides";
+  }
+
   // The base memref and the view memref should be in the same memory space.
   if (baseType.getMemorySpace() != viewType.getMemorySpace())
     return op.emitError("different memory spaces specified for base memref "
diff --git a/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir b/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir
--- a/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir
+++ b/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir
@@ -1185,3 +1185,20 @@
 // CHECK-NEXT: llvm.return %[[ARG]]
   return %1 : f16
 }
+
+// -----
+
+// CHECK-LABEL: func @view_base_offset(
+func @view_base_offset(%0: memref<2045xi8, offset: 3, strides: [1]>) {
+  // Test static base offset.
+
+  // CHECK: %[[DESC:[0-9]+]] = llvm.insertvalue %{{.*}}, %{{.*}}[0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+  // CHECK: %[[BASE:[0-9]+]] = llvm.extractvalue %{{.*}}[1] : !llvm<"{ i8*, i8*, i64, [1 x i64], [1 x i64] }">
+  // CHECK: %[[OFFSET:[0-9]+]] = llvm.mlir.constant(3 : index) : !llvm.i64
+  // CHECK: %[[PTR:[0-9]+]] = llvm.getelementptr %[[BASE]][%[[OFFSET]]] : (!llvm<"i8*">, !llvm.i64) -> !llvm<"i8*">
+  // CHECK: %[[CASTED:[0-9]+]] = llvm.bitcast %[[PTR]] : !llvm<"i8*"> to !llvm<"float*">
+  // CHECK: llvm.insertvalue %[[CASTED]], %{{.*}}[1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+  %8 = view %0[][] : memref<2045xi8, offset: 3, strides: [1]> to memref<2x3xf32>
+
+  return
+}
diff --git a/mlir/test/IR/core-ops.mlir b/mlir/test/IR/core-ops.mlir
--- a/mlir/test/IR/core-ops.mlir
+++ b/mlir/test/IR/core-ops.mlir
@@ -11,6 +11,7 @@
 // CHECK-DAG: #[[VIEW_MAP1:map[0-9]+]] = affine_map<(d0, d1) -> (d0 * 4 + d1)>
 // CHECK-DAG: #[[VIEW_MAP2:map[0-9]+]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + d1 + s0)>
 // CHECK-DAG: #[[VIEW_MAP3:map[0-9]+]] = affine_map<(d0, d1)[s0] -> (d0 * s0 + d1)>
+// CHECK-DAG: #[[VIEW_MAP4:map[0-9]+]] = affine_map<(d0) -> (d0 + 3)>

 // CHECK-DAG: #[[BASE_MAP0:map[0-9]+]] = affine_map<(d0, d1, d2) -> (d0 * 64 + d1 * 4 + d2)>
 // CHECK-DAG: #[[BASE_MAP3:map[0-9]+]] = affine_map<(d0, d1, d2)[s0, s1, s2, s3] -> (d0 * s1 + s0 + d1 * s2 + d2 * s3)>
@@ -703,6 +704,12 @@
   // CHECK: %{{.*}} = std.view %0[][] : memref<2048xi8> to memref<64x4xf32, #[[VIEW_MAP1]]>
   %5 = view %0[][] : memref<2048xi8>
     to memref<64x4xf32, affine_map<(d0, d1) -> (d0 * 4 + d1)>>
+
+  // Test static sizes and static offset with a base offset.
+  %6 = subview %0[][][] : memref<2048xi8> to memref<2045xi8, offset: 3, strides: [1]>
+  // CHECK: %{{.*}} = std.view %6[][] : memref<2045xi8, #[[VIEW_MAP4]]> to memref<64x4xf32, #[[VIEW_MAP1]]>
+  %7 = view %6[][]
+    : memref<2045xi8, offset: 3, strides: [1]> to memref<64x4xf32, affine_map<(d0, d1) -> (d0 * 4 + d1)>>
   return
 }

diff --git a/mlir/test/IR/invalid-ops.mlir b/mlir/test/IR/invalid-ops.mlir
--- a/mlir/test/IR/invalid-ops.mlir
+++ b/mlir/test/IR/invalid-ops.mlir
@@ -957,7 +957,7 @@ func @invalid_view(%arg0 : index, %arg1 : index, %arg2 : index) {
   %0 = alloc() : memref<2048xi8, affine_map<(d0) -> (d0 floordiv 8, d0 mod 8)>>
-  // expected-error@+1 {{unsupported map for base memref}}
+  // expected-error@+1 {{base type 'memref<2048xi8, affine_map<(d0) -> (d0 floordiv 8, d0 mod 8)>>' is not strided}}
   %1 = view %0[][%arg0, %arg1]
     : memref<2048xi8, affine_map<(d0) -> (d0 floordiv 8, d0 mod 8)>> to
       memref<?x?xf32, affine_map<(d0, d1)[s0] -> (d0 * 4 + d1 + s0)>>