diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/Bufferization.h b/mlir/include/mlir/Dialect/Bufferization/IR/Bufferization.h
--- a/mlir/include/mlir/Dialect/Bufferization/IR/Bufferization.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/Bufferization.h
@@ -27,4 +27,26 @@
 #define GET_OP_CLASSES
 #include "mlir/Dialect/Bufferization/IR/BufferizationOps.h.inc"
 
+//===----------------------------------------------------------------------===//
+// Helper functions
+//===----------------------------------------------------------------------===//
+
+namespace mlir {
+namespace bufferization {
+/// Try to cast the given ranked MemRef-typed value to the given ranked MemRef
+/// type. Insert a reallocation + copy if it cannot be statically guaranteed
+/// that a direct cast would be valid.
+///
+/// E.g., when casting from a ranked MemRef type with dynamic layout to a
+/// ranked MemRef type with static layout, it is not statically known whether
+/// the cast will succeed or not. Such `memref.cast` ops may fail at runtime.
+/// This function never generates such casts and conservatively inserts a copy.
+///
+/// This function returns `failure()` in case of unsupported casts, e.g., casts
+/// with differing element types or memory spaces.
+FailureOr<Value> castOrReallocMemRefValue(OpBuilder &b, Value value,
+                                          MemRefType type);
+} // namespace bufferization
+} // namespace mlir
+
 #endif // MLIR_DIALECT_BUFFERIZATION_IR_BUFFERIZATION_H_
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
@@ -1,4 +1,3 @@
-
 //===----------------------------------------------------------------------===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
@@ -13,6 +12,73 @@
 using namespace mlir;
 using namespace mlir::bufferization;
 
+//===----------------------------------------------------------------------===//
+// Helper functions
+//===----------------------------------------------------------------------===//
+
+FailureOr<Value>
+mlir::bufferization::castOrReallocMemRefValue(OpBuilder &b, Value value,
+                                              MemRefType destType) {
+  auto srcType = value.getType().cast<MemRefType>();
+
+  // Casting to the same type, nothing to do.
+  if (srcType == destType)
+    return value;
+
+  // Element type, rank and memory space must match.
+  if (srcType.getElementType() != destType.getElementType())
+    return failure();
+  if (srcType.getMemorySpaceAsInt() != destType.getMemorySpaceAsInt())
+    return failure();
+  if (srcType.getRank() != destType.getRank())
+    return failure();
+
+  // In case the affine maps are different, we may need to use a copy if we go
+  // from dynamic to static offset or stride (the canonicalization cannot know
+  // at this point that it is really cast compatible).
+  auto isGuaranteedCastCompatible = [](MemRefType source, MemRefType target) {
+    int64_t sourceOffset, targetOffset;
+    SmallVector<int64_t, 4> sourceStrides, targetStrides;
+    if (failed(getStridesAndOffset(source, sourceStrides, sourceOffset)) ||
+        failed(getStridesAndOffset(target, targetStrides, targetOffset)))
+      return false;
+    auto dynamicToStatic = [](int64_t a, int64_t b) {
+      return a == MemRefType::getDynamicStrideOrOffset() &&
+             b != MemRefType::getDynamicStrideOrOffset();
+    };
+    if (dynamicToStatic(sourceOffset, targetOffset))
+      return false;
+    for (auto it : zip(sourceStrides, targetStrides))
+      if (dynamicToStatic(std::get<0>(it), std::get<1>(it)))
+        return false;
+    return true;
+  };
+
+  // Note: If `areCastCompatible` returns true, a cast is valid, but may fail
+  // at runtime. To ensure that we only generate casts that always succeed at
+  // runtime, we check a few extra conditions in `isGuaranteedCastCompatible`.
+  if (memref::CastOp::areCastCompatible(srcType, destType) &&
+      isGuaranteedCastCompatible(srcType, destType)) {
+    Value casted = b.create<memref::CastOp>(value.getLoc(), destType, value);
+    return casted;
+  }
+
+  auto loc = value.getLoc();
+  SmallVector<Value, 4> dynamicOperands;
+  for (int i = 0; i < destType.getRank(); ++i) {
+    if (destType.getShape()[i] != ShapedType::kDynamicSize)
+      continue;
+    auto index = b.createOrFold<arith::ConstantIndexOp>(loc, i);
+    Value size = b.create<memref::DimOp>(loc, value, index);
+    dynamicOperands.push_back(size);
+  }
+  // TODO: Use alloc/memcpy callback from BufferizationOptions if called via
+  // BufferizableOpInterface impl of ToMemrefOp.
+  Value copy = b.create<memref::AllocOp>(loc, destType, dynamicOperands);
+  b.create<memref::CopyOp>(loc, value, copy);
+  return copy;
+}
+
 //===----------------------------------------------------------------------===//
 // CloneOp
 //===----------------------------------------------------------------------===//
@@ -191,67 +257,39 @@
   if (!memrefToTensor)
     return failure();
 
-  // A memref_to_tensor + tensor_to_memref with same types can be folded without
-  // inserting a cast.
-  if (memrefToTensor.memref().getType() == toMemref.getType()) {
-    if (!allowSameType)
-      // Function can be configured to only handle cases where a cast is needed.
+  Type srcType = memrefToTensor.memref().getType();
+  Type destType = toMemref.getType();
+
+  // Function can be configured to only handle cases where a cast is needed.
+  if (!allowSameType && srcType == destType)
+    return failure();
+
+  auto rankedSrcType = srcType.dyn_cast<MemRefType>();
+  auto rankedDestType = destType.dyn_cast<MemRefType>();
+  auto unrankedSrcType = srcType.dyn_cast<UnrankedMemRefType>();
+
+  // Ranked memref -> Ranked memref cast.
+  if (rankedSrcType && rankedDestType) {
+    FailureOr<Value> replacement = castOrReallocMemRefValue(
+        rewriter, memrefToTensor.memref(), rankedDestType);
+    if (failed(replacement))
       return failure();
-    rewriter.replaceOp(toMemref, memrefToTensor.memref());
+
+    rewriter.replaceOp(toMemref, *replacement);
     return success();
   }
 
-  // If types are definitely not cast-compatible, bail.
-  if (!memref::CastOp::areCastCompatible(memrefToTensor.memref().getType(),
-                                         toMemref.getType()))
+  // Unranked memref -> Ranked memref cast: May require a copy.
+  // TODO: Not implemented at the moment.
+  if (unrankedSrcType && rankedDestType)
    return failure();
 
-  // We already know that the types are potentially cast-compatible. However
-  // in case the affine maps are different, we may need to use a copy if we go
-  // from dynamic to static offset or stride (the canonicalization cannot know
-  // at this point that it is really cast compatible).
-  auto isGuaranteedCastCompatible = [](MemRefType source, MemRefType target) {
-    int64_t sourceOffset, targetOffset;
-    SmallVector<int64_t, 4> sourceStrides, targetStrides;
-    if (failed(getStridesAndOffset(source, sourceStrides, sourceOffset)) ||
-        failed(getStridesAndOffset(target, targetStrides, targetOffset)))
-      return false;
-    auto dynamicToStatic = [](int64_t a, int64_t b) {
-      return a == MemRefType::getDynamicStrideOrOffset() &&
-             b != MemRefType::getDynamicStrideOrOffset();
-    };
-    if (dynamicToStatic(sourceOffset, targetOffset))
-      return false;
-    for (auto it : zip(sourceStrides, targetStrides))
-      if (dynamicToStatic(std::get<0>(it), std::get<1>(it)))
-        return false;
-    return true;
-  };
-
-  auto memrefToTensorType =
-      memrefToTensor.memref().getType().dyn_cast<MemRefType>();
-  auto toMemrefType = toMemref.getType().dyn_cast<MemRefType>();
-  if (memrefToTensorType && toMemrefType &&
-      !isGuaranteedCastCompatible(memrefToTensorType, toMemrefType)) {
-    MemRefType resultType = toMemrefType;
-    auto loc = toMemref.getLoc();
-    SmallVector<Value, 4> dynamicOperands;
-    for (int i = 0; i < resultType.getRank(); ++i) {
-      if (resultType.getShape()[i] != ShapedType::kDynamicSize)
-        continue;
-      auto index = rewriter.createOrFold<arith::ConstantIndexOp>(loc, i);
-      Value size = rewriter.create<tensor::DimOp>(loc, memrefToTensor, index);
-      dynamicOperands.push_back(size);
-    }
-    // TODO: Use alloc/memcpy callback from BufferizationOptions if called via
-    // BufferizableOpInterface impl of ToMemrefOp.
-    auto copy =
-        rewriter.create<memref::AllocOp>(loc, resultType, dynamicOperands);
-    rewriter.create<memref::CopyOp>(loc, memrefToTensor.memref(), copy);
-    rewriter.replaceOp(toMemref, {copy});
-  } else
-    rewriter.replaceOpWithNewOp<memref::CastOp>(toMemref, toMemref.getType(),
-                                                memrefToTensor.memref());
+  // Unranked memref -> unranked memref cast
+  // Ranked memref -> unranked memref cast: No copy needed.
+  assert(memref::CastOp::areCastCompatible(srcType, destType) &&
+         "expected that types are cast compatible");
+  rewriter.replaceOpWithNewOp<memref::CastOp>(toMemref, destType,
+                                              memrefToTensor.memref());
 
   return success();
 }
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
--- a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
@@ -45,9 +45,28 @@
   addSourceMaterialization(materializeToTensor);
   addTargetMaterialization([](OpBuilder &builder, BaseMemRefType type,
                               ValueRange inputs, Location loc) -> Value {
-    assert(inputs.size() == 1);
-    assert(inputs[0].getType().isa<TensorType>());
-    return builder.create<bufferization::ToMemrefOp>(loc, type, inputs[0]);
+    assert(inputs.size() == 1 && "expected exactly one input");
+
+    if (auto inputType = inputs[0].getType().dyn_cast<BaseMemRefType>()) {
+      // MemRef to MemRef cast.
+      assert(inputType != type && "expected different types");
+      // Unranked to ranked and ranked to unranked casts must be explicit.
+      auto rankedDestType = type.dyn_cast<MemRefType>();
+      if (!rankedDestType)
+        return nullptr;
+      FailureOr<Value> replacement =
+          castOrReallocMemRefValue(builder, inputs[0], rankedDestType);
+      if (failed(replacement))
+        return nullptr;
+      return *replacement;
+    }
+
+    if (inputs[0].getType().isa<TensorType>()) {
+      // Tensor to MemRef cast.
+      return builder.create<bufferization::ToMemrefOp>(loc, type, inputs[0]);
+    }
+
+    llvm_unreachable("only tensor/memref input types supported");
   });
 }
diff --git a/mlir/test/Dialect/Bufferization/Transforms/finalizing-bufferize.mlir b/mlir/test/Dialect/Bufferization/Transforms/finalizing-bufferize.mlir
--- a/mlir/test/Dialect/Bufferization/Transforms/finalizing-bufferize.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/finalizing-bufferize.mlir
@@ -26,3 +26,77 @@
   "test.sink"(%0) : (tensor<f32>) -> ()
   return
 }
+
+// -----
+
+// CHECK: #[[$map1:.*]] = affine_map<(d0)[s0] -> (d0 + s0)>
+// CHECK-LABEL: func @dyn_layout_to_no_layout_cast(
+// CHECK-SAME: %[[arg:.*]]: memref<?xf32, #[[$map1]]>)
+// CHECK: %[[c0:.*]] = arith.constant 0 : index
+// CHECK: %[[dim:.*]] = memref.dim %[[arg]], %[[c0]]
+// CHECK: %[[alloc:.*]] = memref.alloc(%[[dim]]) : memref<?xf32>
+// CHECK: memref.copy %[[arg]], %[[alloc]]
+// CHECK: return %[[alloc]]
+#map1 = affine_map<(d0)[s0] -> (d0 + s0)>
+func @dyn_layout_to_no_layout_cast(%m: memref<?xf32, #map1>) -> memref<?xf32> {
+  %0 = bufferization.to_tensor %m : memref<?xf32, #map1>
+  %1 = bufferization.to_memref %0 : memref<?xf32>
+  return %1 : memref<?xf32>
+}
+
+// -----
+
+// CHECK: #[[$map2:.*]] = affine_map<(d0)[s0] -> (d0 * 100 + s0)>
+// CHECK-LABEL: func @fancy_layout_to_no_layout_cast(
+// CHECK-SAME: %[[arg:.*]]: memref<?xf32, #[[$map2]]>)
+// CHECK: %[[c0:.*]] = arith.constant 0 : index
+// CHECK: %[[dim:.*]] = memref.dim %[[arg]], %[[c0]]
+// CHECK: %[[alloc:.*]] = memref.alloc(%[[dim]]) : memref<?xf32>
+// CHECK: memref.copy %[[arg]], %[[alloc]]
+// CHECK: return %[[alloc]]
+#map2 = affine_map<(d0)[s0] -> (d0 * 100 + s0)>
+func @fancy_layout_to_no_layout_cast(%m: memref<?xf32, #map2>) -> memref<?xf32> {
+  %0 = bufferization.to_tensor %m : memref<?xf32, #map2>
+  %1 = bufferization.to_memref %0 : memref<?xf32>
+  return %1 : memref<?xf32>
+}
+
+// -----
+
+// CHECK: #[[$map3:.*]] = affine_map<(d0)[s0] -> (d0 + 25)>
+// CHECK-LABEL: func @static_layout_to_no_layout_cast(
+// CHECK-SAME: %[[arg:.*]]: memref<?xf32, #[[$map3]]>)
+// CHECK: %[[c0:.*]] = arith.constant 0 : index
+// CHECK: %[[dim:.*]] = memref.dim %[[arg]], %[[c0]]
+// CHECK: %[[alloc:.*]] = memref.alloc(%[[dim]]) : memref<?xf32>
+// CHECK: memref.copy %[[arg]], %[[alloc]]
+// CHECK: return %[[alloc]]
+#map3 = affine_map<(d0)[s0] -> (d0 + 25)>
+func @static_layout_to_no_layout_cast(%m: memref<?xf32, #map3>) -> memref<?xf32> {
+  %0 = bufferization.to_tensor %m : memref<?xf32, #map3>
+  %1 = bufferization.to_memref %0 : memref<?xf32>
+  return %1 : memref<?xf32>
+}
+
+// -----
+
+// TODO: to_memref with layout maps not supported yet. This should fold to a
+// memref.cast.
+#map4 = affine_map<(d0)[s0] -> (d0 + s0)>
+func @no_layout_to_dyn_layout_cast(%m: memref<?xf32>) -> memref<?xf32, #map4> {
+  %0 = bufferization.to_tensor %m : memref<?xf32>
+  // expected-error @+1 {{failed to materialize conversion for result #0 of operation 'bufferization.to_memref' that remained live after conversion}}
+  %1 = bufferization.to_memref %0 : memref<?xf32, #map4>
+  // expected-note @+1 {{see existing live user here}}
+  return %1 : memref<?xf32, #map4>
+}
+
+// -----
+
+func @illegal_unranked_to_rank(%m: memref<*xf32>) -> memref<?xf32> {
+  // expected-note @+1 {{prior use here}}
+  %0 = bufferization.to_tensor %m : memref<*xf32>
+  // expected-error @+1 {{expects different type than prior uses: 'tensor<?xf32>' vs 'tensor<*xf32>'}}
+  %1 = bufferization.to_memref %0 : memref<?xf32>
+  return %1 : memref<?xf32>
+}
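
Not part of the patch itself: a minimal C++ sketch of how a caller might drive the new helper, shown only as illustration. The wrapper name `getBufferOfExactType` is hypothetical; only `castOrReallocMemRefValue` comes from the change above.

// Hypothetical usage sketch (assumes the Bufferization.h header added above is
// included). Adapt a ranked memref value to the exact MemRefType a consumer
// expects: the helper emits a memref.cast only when the cast is guaranteed to
// succeed at runtime and otherwise falls back to memref.alloc + memref.copy.
static Value getBufferOfExactType(OpBuilder &b, Value buffer,
                                  MemRefType requiredType) {
  FailureOr<Value> adapted =
      mlir::bufferization::castOrReallocMemRefValue(b, buffer, requiredType);
  // failure() means the cast is unsupported (e.g., different element type,
  // rank or memory space); no IR has been created in that case.
  if (failed(adapted))
    return Value();
  return *adapted;
}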