diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp --- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp +++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp @@ -319,12 +319,59 @@ // types. `BufferCastOp::fold` handles the same type case. if (!tensorLoad || tensorLoad.memref().getType() == bufferCast.getType()) return failure(); - // If types are not cast-compatible, bail. + // If types are definitely not cast-compatible, bail. if (!CastOp::areCastCompatible(tensorLoad.memref().getType(), bufferCast.getType())) return failure(); - rewriter.replaceOpWithNewOp(bufferCast, bufferCast.getType(), - tensorLoad.memref()); + + // We already know that the types are potentially cast-compatible. However + // in case the affine maps are different, we may need to use a copy if we go + // from dynamic to static offset or stride (the canonicalization cannot know + // at this point that it is really cast compatible). + auto isGuaranteedCastCompatible = [](MemRefType source, MemRefType target) { + int64_t sourceOffset, targetOffset; + SmallVector sourceStrides, targetStrides; + if (failed(getStridesAndOffset(source, sourceStrides, sourceOffset)) || + failed(getStridesAndOffset(target, targetStrides, targetOffset))) + return false; + auto dynamicToStatic = [](int64_t a, int64_t b) { + return a == MemRefType::getDynamicStrideOrOffset() && + b != MemRefType::getDynamicStrideOrOffset(); + }; + if (dynamicToStatic(sourceOffset, targetOffset)) + return false; + for (auto sourceStride : enumerate(sourceStrides)) + if (dynamicToStatic(sourceStride.value(), + targetStrides[sourceStride.index()])) + return false; + return true; + }; + + auto tensorLoadType = tensorLoad.memref().getType().dyn_cast(); + auto bufferCastType = bufferCast.getType().dyn_cast(); + if (tensorLoadType && bufferCastType && + !isGuaranteedCastCompatible(tensorLoadType, bufferCastType)) { + MemRefType resultType = bufferCastType; + auto loc = bufferCast.getLoc(); + SmallVector dynamicOperands; + for (int i = 0; i < resultType.getRank(); ++i) { + if (resultType.getShape()[i] != ShapedType::kDynamicSize) + continue; + auto index = rewriter.createOrFold(loc, i); + Value size = rewriter.create(loc, tensorLoad, index); + if (!size.getType().isIndex()) { + size = + rewriter.create(loc, size, rewriter.getIndexType()); + } + dynamicOperands.push_back(size); + } + auto copy = + rewriter.create(loc, resultType, dynamicOperands); + rewriter.create(loc, tensorLoad.memref(), copy); + rewriter.replaceOp(bufferCast, {copy}); + } else + rewriter.replaceOpWithNewOp(bufferCast, bufferCast.getType(), + tensorLoad.memref()); return success(); } }; diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir --- a/mlir/test/Dialect/MemRef/canonicalize.mlir +++ b/mlir/test/Dialect/MemRef/canonicalize.mlir @@ -46,16 +46,18 @@ // CHECK-DAG: #[[$OFF_3:[a-z0-9]+]] = affine_map<(d0) -> (d0 + 3)> // CHECK-DAG: #[[$OFF_UNK:[a-z0-9]+]] = affine_map<(d0)[s0] -> (d0 + s0)> -// Test case: If the memrefs are cast-compatible, canonicalize. +// Test case: If the memrefs are definitely cast-compatible, canonicalize to +// cast. // CHECK-LABEL: func @canonicalize_buffer_cast_of_tensor_load( // CHECK-SAME: %[[M:.*]]: memref) -// CHEKC-SAME: -> memref { +// CHECK-SAME: -> memref { // CHECK-NOT: memref.tensor_load // CHECK-NOT: memref.buffer_cast // CHECK: %[[R:.*]] = memref.cast %[[M]] // CHECK-SAME: memref to memref // CHECK: return %[[R]] -func @canonicalize_buffer_cast_of_tensor_load(%arg0: memref) +func @canonicalize_buffer_cast_of_tensor_load( + %arg0: memref) -> memref { %0 = memref.tensor_load %arg0 : memref @@ -65,6 +67,33 @@ // ----- +// CHECK-DAG: #[[$OFF_UNK:[a-z0-9]+]] = affine_map<(d0)[s0] -> (d0 + s0)> +// CHECK-DAG: #[[$OFF_3:[a-z0-9]+]] = affine_map<(d0) -> (d0 + 3)> + +// Test case: If the memrefs are potentially cast-compatible, canonicalize to +// copy. +// CHECK-LABEL: func @canonicalize_buffer_cast_of_tensor_load_to_copy( +// CHECK-SAME: %[[M:.*]]: memref) +// CHECK-SAME: -> memref { +// CHECK-NOT: memref.tensor_load +// CHECK-NOT: memref.buffer_cast +// CHECK: %[[C0:.*]] = constant 0 : index +// CHECK: %[[DIM:.*]] = memref.dim %[[M]], %[[C0]] : memref +// CHECK: %[[ALLOC:.*]] = memref.alloc(%[[DIM]]) : memref +// CHECK: memref.copy %[[M]], %[[ALLOC]] +// CHECK-SAME: memref to memref +// CHECK: return %[[ALLOC]] +func @canonicalize_buffer_cast_of_tensor_load_to_copy( + %arg0: memref) + -> memref +{ + %0 = memref.tensor_load %arg0 : memref + %1 = memref.buffer_cast %0 : memref + return %1 : memref +} + +// ----- + // CHECK-LABEL: func @subview_of_memcast // CHECK-SAME: %[[ARG0:.[a-z0-9A-Z_]+]]: memref<4x6x16x32xi8> // CHECK: %[[S:.+]] = memref.subview %arg0[0, 1, 0] [1, 1, 16] [1, 1, 1] : memref<4x6x16x32xi8> to memref<16x32xi8, #{{.*}}>