diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
@@ -495,6 +495,19 @@
 FailureOr<BaseMemRefType> getBufferType(Value value,
                                         const BufferizationOptions &options);
 
+/// Return the buffer type for a given Value (tensor) after bufferization
+/// without bufferizing any IR. If at any point during the type computation, the
+/// type of a value in `fixedTypes` is required, the mapped type is used.
+///
+/// Note: It should be sufficient to call `getBuffer()->getType()` in most
+/// cases. However, when a buffer type should be predicted without modifying any
+/// IR, this function can be used.
+///
+/// This function is a wrapper around BufferizableOpInterface::getBufferType.
+FailureOr<BaseMemRefType>
+getBufferType(Value value, const BufferizationOptions &options,
+              const DenseMap<Value, BaseMemRefType> &fixedTypes);
+
 /// Replace an op with replacement values. The op is deleted. Tensor OpResults
 /// must be replaced with memref values.
 void replaceOpWithBufferizedValues(RewriterBase &rewriter, Operation *op,
@@ -551,7 +564,8 @@
 /// BufferizableOpInterface::getBufferType. Should not be called from other
 /// places.
 FailureOr<BaseMemRefType>
-defaultGetBufferType(Value value, const BufferizationOptions &options);
+defaultGetBufferType(Value value, const BufferizationOptions &options,
+                     const DenseMap<Value, BaseMemRefType> &fixedTypes);
 
 } // namespace detail
 } // namespace bufferization
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td
@@ -350,12 +350,14 @@
       /*retType=*/"FailureOr<BaseMemRefType>",
       /*methodName=*/"getBufferType",
       /*args=*/(ins "Value":$value,
-                    "const BufferizationOptions &":$options),
+                    "const BufferizationOptions &":$options,
+                    "const DenseMap<Value, BaseMemRefType>":$fixedTypes),
       /*methodBody=*/"",
       /*defaultImplementation=*/[{
         assert(getOwnerOfValue(value) == $_op.getOperation() &&
                "expected that value belongs to this op");
-        return bufferization::detail::defaultGetBufferType(value, options);
+        return bufferization::detail::defaultGetBufferType(
+            value, options, fixedTypes);
       }]
     >,
   ];
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
@@ -92,7 +92,8 @@
                              OpOperand &opOperand, const AnalysisState &state);
 
     FailureOr<BaseMemRefType> getBufferType(
-        Value value, const BufferizationOptions &options);
+        Value value, const BufferizationOptions &options,
+        const DenseMap<Value, BaseMemRefType> &fixedTypes);
 
     RankedTensorType getType() {
       return getResult().getType().cast<RankedTensorType>();
     }
@@ -323,7 +324,8 @@
     }
 
     FailureOr<BaseMemRefType> getBufferType(
-        Value value, const BufferizationOptions &options) {
+        Value value, const BufferizationOptions &options,
+        const DenseMap<Value, BaseMemRefType> &fixedTypes) {
       return getMemref().getType().cast<BaseMemRefType>();
     }
  }];
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
@@ -568,7 +568,8 @@
 }
 
 FailureOr<BaseMemRefType>
 bufferization::detail::defaultGetBufferType(
-    Value value, const BufferizationOptions &options) {
+    Value value, const BufferizationOptions &options,
+    const DenseMap<Value, BaseMemRefType> &fixedTypes) {
   assert(value.getType().isa<TensorType>() && "expected tensor type");
 
   // No further analysis is possible for a block argument.
@@ -587,7 +588,7 @@
     // If the OpResult has an equivalent OpOperand, both OpResult and
     // OpOperand bufferize to the exact same buffer type.
     Value equivalentOperand = aliasingOperands.front()->get();
-    return getBufferType(equivalentOperand, options);
+    return getBufferType(equivalentOperand, options, fixedTypes);
   }
 
   // If we do not know the memory space and there is no default memory space,
@@ -602,11 +603,26 @@
 /// Return the buffer type for a given Value (tensor) after bufferization.
 FailureOr<BaseMemRefType>
 bufferization::getBufferType(Value value, const BufferizationOptions &options) {
+  DenseMap<Value, BaseMemRefType> fixedTypes;
+  return getBufferType(value, options, fixedTypes);
+}
+
+/// Return the buffer type for a given Value (tensor) after bufferization.
+FailureOr<BaseMemRefType> bufferization::getBufferType(
+    Value value, const BufferizationOptions &options,
+    const DenseMap<Value, BaseMemRefType> &fixedTypes) {
   assert(value.getType().isa<TensorType>() && "unexpected non-tensor type");
+
+  // If the `value` is in `fixedTypes`, return the mapped type.
+  const auto &it = fixedTypes.find(value);
+  if (it != fixedTypes.end())
+    return it->second;
+
+  // Try querying BufferizableOpInterface.
   Operation *op = getOwnerOfValue(value);
   auto bufferizableOp = options.dynCastBufferizableOp(op);
   if (bufferizableOp)
-    return bufferizableOp.getBufferType(value, options);
+    return bufferizableOp.getBufferType(value, options, fixedTypes);
 
   // Op is not bufferizable.
   if (!options.defaultMemorySpace.has_value())
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
@@ -170,7 +170,7 @@
   }
 
   // Create memory allocation.
-  auto allocType = getBufferType(getResult(), options);
+  auto allocType = bufferization::getBufferType(getResult(), options);
   if (failed(allocType))
     return failure();
   SmallVector<Value> dynamicDims = getDynamicSizes();
@@ -233,8 +233,9 @@
   return {};
 }
 
-FailureOr<BaseMemRefType>
-AllocTensorOp::getBufferType(Value value, const BufferizationOptions &options) {
+FailureOr<BaseMemRefType> AllocTensorOp::getBufferType(
+    Value value, const BufferizationOptions &options,
+    const DenseMap<Value, BaseMemRefType> &fixedTypes) {
   assert(value == getResult() && "invalid value");
 
   // Compute memory space of this allocation.
@@ -242,7 +243,8 @@
   if (getMemorySpace().has_value()) {
     memorySpace = *getMemorySpace();
   } else if (getCopy()) {
-    auto copyBufferType = bufferization::getBufferType(getCopy(), options);
+    auto copyBufferType =
+        bufferization::getBufferType(getCopy(), options, fixedTypes);
     if (failed(copyBufferType))
       return failure();
     memorySpace = copyBufferType->getMemorySpaceAsInt();
diff --git a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
--- a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -472,15 +472,76 @@
   }
 
   FailureOr<BaseMemRefType>
-  getBufferType(Operation *op, Value value,
-                const BufferizationOptions &options) const {
+  getBufferType(Operation *op, Value value, const BufferizationOptions &options,
+                const DenseMap<Value, BaseMemRefType> &fixedTypes) const {
     auto forOp = cast<scf::ForOp>(op);
-    // TODO: Only block arguments supported at the moment.
-    if (value.isa<OpResult>())
+    assert(getOwnerOfValue(value) == op && "invalid value");
+    assert(value.getType().isa<TensorType>() && "expected tensor type");
+
+    // Get result/argument number.
+    unsigned resultNum;
+    if (auto bbArg = value.dyn_cast<BlockArgument>()) {
+      resultNum =
+          forOp.getResultForOpOperand(forOp.getOpOperandForRegionIterArg(bbArg))
+              .getResultNumber();
+    } else {
+      resultNum = value.cast<OpResult>().getResultNumber();
+    }
+
+    // Determine the buffer type of the init_arg.
+    Value initArg = forOp.getInitArgs()[resultNum];
+    auto initArgBufferType =
+        bufferization::getBufferType(initArg, options, fixedTypes);
+    if (failed(initArgBufferType))
       return failure();
-    auto bbArg = value.cast<BlockArgument>();
-    return bufferization::getBufferType(
-        forOp.getOpOperandForRegionIterArg(bbArg).get(), options);
+
+    // Fix the iter_arg type, so that recursive lookups return the buffer type
+    // of the init_arg. This is to avoid infinite loops when calculating the
+    // buffer type of the yielded value.
+    //
+    // Note: For more precise layout map computation, a fixpoint iteration could
+    // be done (i.e., re-computing the yielded buffer type until the bufferized
+    // iter_arg type no longer changes). The current implementation immediately
+    // switches to a fully dynamic layout map when a mismatch between bufferized
+    // init_arg type and bufferized yield value type is detected.
+    DenseMap<Value, BaseMemRefType> newFixedTypes(fixedTypes);
+    newFixedTypes[forOp.getRegionIterArgs()[resultNum]] = *initArgBufferType;
+
+    // Compute the buffer type of the yielded value.
+    auto yieldOp =
+        cast<scf::YieldOp>(forOp.getLoopBody().front().getTerminator());
+    Value yieldedValue = yieldOp.getOperand(resultNum);
+    BaseMemRefType yieldedValueBufferType;
+    if (yieldedValue.getType().isa<BaseMemRefType>()) {
+      // scf.yield was already bufferized.
+      yieldedValueBufferType = yieldedValue.getType().cast<BaseMemRefType>();
+    } else {
+      auto maybeBufferType =
+          bufferization::getBufferType(yieldedValue, options, newFixedTypes);
+      if (failed(maybeBufferType))
+        return failure();
+      yieldedValueBufferType = *maybeBufferType;
+    }
+
+    // If yielded type and init_arg type are the same, use that type directly.
+    if (*initArgBufferType == yieldedValueBufferType)
+      return yieldedValueBufferType;
+
+    // If there is a mismatch between the yielded buffer type and the iter_arg
+    // buffer type, the buffer type must be promoted to a fully dynamic layout
+    // map.
+    auto yieldedRanked = yieldedValueBufferType.cast<MemRefType>();
+#ifndef NDEBUG
+    auto iterRanked = initArgBufferType->cast<MemRefType>();
+    assert(llvm::equal(yieldedRanked.getShape(), iterRanked.getShape()) &&
+           "expected same shape");
+    assert(yieldedRanked.getMemorySpaceAsInt() ==
+               iterRanked.getMemorySpaceAsInt() &&
+           "expected same memory space");
+#endif // NDEBUG
+    return getMemRefTypeWithFullyDynamicLayout(
+        forOp.getRegionIterArgs()[resultNum].getType().cast<TensorType>(),
+        yieldedRanked.getMemorySpaceAsInt());
   }
 
   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
@@ -499,13 +560,22 @@
       return failure();
     SmallVector<Value> initArgs = *maybeInitArgs;
 
+    // Cast init_args if necessary.
+    SmallVector<Value> castedInitArgs;
+    for (const auto &it : llvm::enumerate(initArgs)) {
+      Value initArg = it.value();
+      auto targetType =
+          bufferization::getBufferType(forOp->getResult(it.index()), options);
+      if (failed(targetType))
+        return failure();
+      castedInitArgs.push_back(castBuffer(rewriter, initArg, *targetType));
+    }
+
     // Construct a new scf.for op with memref instead of tensor values.
     auto newForOp = rewriter.create<scf::ForOp>(
         forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
-        forOp.getStep(), initArgs);
+        forOp.getStep(), castedInitArgs);
     newForOp->setAttrs(forOp->getAttrs());
-    ValueRange initArgsRange(initArgs);
-    TypeRange initArgsTypes(initArgsRange);
     Block *loopBody = &newForOp.getLoopBody().front();
 
     // Set up new iter_args. The loop body uses tensors, so wrap the (memref)
@@ -904,10 +974,8 @@
       return failure();
     Value buffer = *maybeBuffer;
     if (auto forOp = dyn_cast<scf::ForOp>(yieldOp->getParentOp())) {
-      FailureOr<BaseMemRefType> resultType =
-          cast<BufferizableOpInterface>(forOp.getOperation())
-              .getBufferType(forOp.getRegionIterArgs()[it.index()],
-                             options);
+      FailureOr<BaseMemRefType> resultType = bufferization::getBufferType(
+          forOp.getRegionIterArgs()[it.index()], options);
       if (failed(resultType))
         return failure();
       buffer = castBuffer(rewriter, buffer, *resultType);
diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
--- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -293,7 +293,7 @@
 
     // Take a subview of the source buffer.
     auto resultMemrefType =
-        getBufferType(op, extractSliceOp.getResult(), options);
+        bufferization::getBufferType(extractSliceOp.getResult(), options);
     if (failed(resultMemrefType))
       return failure();
     Value subView = rewriter.create<memref::SubViewOp>(
@@ -305,12 +305,12 @@
   }
 
   FailureOr<BaseMemRefType>
-  getBufferType(Operation *op, Value value,
-                const BufferizationOptions &options) const {
+  getBufferType(Operation *op, Value value, const BufferizationOptions &options,
+                const DenseMap<Value, BaseMemRefType> &fixedTypes) const {
     auto extractSliceOp = cast<tensor::ExtractSliceOp>(op);
     assert(value == extractSliceOp.getResult() && "invalid value");
-    auto srcMemrefType =
-        bufferization::getBufferType(extractSliceOp.getSource(), options);
+    auto srcMemrefType = bufferization::getBufferType(
+        extractSliceOp.getSource(), options, fixedTypes);
     if (failed(srcMemrefType))
       return failure();
     SmallVector<OpFoldResult> mixedOffsets = extractSliceOp.getMixedOffsets();
diff --git a/mlir/test/Dialect/SCF/one-shot-bufferize.mlir b/mlir/test/Dialect/SCF/one-shot-bufferize.mlir
--- a/mlir/test/Dialect/SCF/one-shot-bufferize.mlir
+++ b/mlir/test/Dialect/SCF/one-shot-bufferize.mlir
@@ -742,3 +742,25 @@
   }
   return %r : tensor<?xf32>
 }
+
+// -----
+
+// We just check that this example bufferizes to valid IR.
+
+// CHECK-LABEL: func @scf_for_buffer_type_mismatch
+func.func @scf_for_buffer_type_mismatch(%sz: index, %sz2: index) -> f32 {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c10 = arith.constant 10 : index
+  %0 = bufferization.alloc_tensor(%sz) : tensor<?xf32>
+  %e2 = tensor.extract_slice %0[1][%sz2][1] : tensor<?xf32> to tensor<?xf32>
+  // init_arg and iter_arg have different buffer types. This must be resolved
+  // with casts.
+  %r = scf.for %iv = %c0 to %c10 step %c1 iter_args(%t = %e2) -> tensor<?xf32> {
+    %s = "test.dummy"() : () -> (index)
+    %e = tensor.extract_slice %t[1][%s][1] : tensor<?xf32> to tensor<?xf32>
+    scf.yield %e : tensor<?xf32>
+  }
+  %x = tensor.extract %r[%c1] : tensor<?xf32>
+  return %x : f32
+}
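
Usage note (illustrative, not part of the patch): the `fixedTypes` map lets an
op implementation pre-commit the buffer type of a value so that recursive
`getBufferType` queries terminate instead of cycling back into the same op.
A minimal sketch in the spirit of the scf.for implementation above, assumed to
run inside a hypothetical `getBufferType` implementation where `initArg`,
`iterArg`, and `yieldedValue` stand for a loop's init value, region iter_arg,
and yielded tensor value:

  // Compute the init value's buffer type without modifying any IR.
  FailureOr<BaseMemRefType> initArgType =
      bufferization::getBufferType(initArg, options, fixedTypes);
  if (failed(initArgType))
    return failure();

  // Pin the iter_arg to that type so that the recursive query on the yielded
  // value does not loop back into this op.
  DenseMap<Value, BaseMemRefType> newFixedTypes(fixedTypes);
  newFixedTypes[iterArg] = *initArgType;

  FailureOr<BaseMemRefType> yieldedType =
      bufferization::getBufferType(yieldedValue, options, newFixedTypes);

If the init type and the yielded type differ, the scf.for implementation above
falls back to a fully dynamic layout map rather than iterating to a fixpoint.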