diff --git a/mlir/lib/Dialect/Arithmetic/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Arithmetic/Transforms/BufferizableOpInterfaceImpl.cpp
--- a/mlir/lib/Dialect/Arithmetic/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Arithmetic/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -155,27 +155,19 @@
       return failure();
     Value trueBuffer = *maybeTrueBuffer;
     Value falseBuffer = *maybeFalseBuffer;
-    BaseMemRefType trueType = trueBuffer.getType().cast<BaseMemRefType>();
-    BaseMemRefType falseType = falseBuffer.getType().cast<BaseMemRefType>();
-    if (trueType.getMemorySpaceAsInt() != falseType.getMemorySpaceAsInt())
-      return op->emitError("inconsistent memory space on true/false operands");
 
     // The "true" and the "false" operands must have the same type. If the
     // buffers have different types, they differ only in their layout map. Cast
     // both of them to the most dynamic MemRef type.
     if (trueBuffer.getType() != falseBuffer.getType()) {
-      auto trueType = trueBuffer.getType().cast<MemRefType>();
-      int64_t dynamicOffset = ShapedType::kDynamicStrideOrOffset;
-      SmallVector<int64_t> dynamicStrides(trueType.getRank(),
-                                          ShapedType::kDynamicStrideOrOffset);
-      AffineMap stridedLayout = makeStridedLinearLayoutMap(
-          dynamicStrides, dynamicOffset, op->getContext());
-      auto castedType =
-          MemRefType::get(trueType.getShape(), trueType.getElementType(),
-                          stridedLayout, trueType.getMemorySpaceAsInt());
-      trueBuffer = rewriter.create<memref::CastOp>(loc, castedType, trueBuffer);
+      auto targetType =
+          bufferization::getBufferType(selectOp.getResult(), options);
+      if (failed(targetType))
+        return failure();
+      trueBuffer =
+          rewriter.create<memref::CastOp>(loc, *targetType, trueBuffer);
       falseBuffer =
-          rewriter.create<memref::CastOp>(loc, castedType, falseBuffer);
+          rewriter.create<memref::CastOp>(loc, *targetType, falseBuffer);
     }
 
     replaceOpWithNewBufferizedOp<arith::SelectOp>(
@@ -183,6 +175,38 @@
     return success();
   }
 
+  /// Compute the buffer type of the result of the select op.
+  ///
+  /// Returns failure if the buffer type of either operand cannot be computed.
+  /// If both operand buffer types are equal, that type is returned unchanged.
+  /// If they are in different memory spaces, an error is emitted on `op`.
+  /// Otherwise the types are assumed to differ only in their layout map and
+  /// the result is built by getMemRefTypeWithFullyDynamicLayout.
+  FailureOr<BaseMemRefType>
+  getBufferType(Operation *op, Value value, const BufferizationOptions &options,
+                const DenseMap<Value, BaseMemRefType> &fixedTypes) const {
+    auto selectOp = cast<arith::SelectOp>(op);
+    assert(value == selectOp.getResult() && "invalid value");
+    auto trueType = bufferization::getBufferType(selectOp.getTrueValue(),
+                                                 options, fixedTypes);
+    auto falseType = bufferization::getBufferType(selectOp.getFalseValue(),
+                                                  options, fixedTypes);
+    if (failed(trueType) || failed(falseType))
+      return failure();
+    if (*trueType == *falseType)
+      return *trueType;
+    if (trueType->getMemorySpaceAsInt() != falseType->getMemorySpaceAsInt())
+      return op->emitError("inconsistent memory space on true/false operands");
+
+    // If the buffers have different types, they differ only in their layout
+    // map.
+    auto memrefType = trueType->cast<MemRefType>();
+    return getMemRefTypeWithFullyDynamicLayout(
+        RankedTensorType::get(memrefType.getShape(),
+                              memrefType.getElementType()),
+        memrefType.getMemorySpaceAsInt());
+  }
+
   BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                 const AnalysisState &state) const {
     return BufferRelation::None;