// Reviewer notes. Everything from the `diff --git` header below is the patch
// text itself and is left unchanged.
//
// Purpose: make bufferization of a select op cope with "true" and "false"
// buffers whose types differ, and add a regression test for that case.
//
// mlir/lib/Dialect/Arithmetic/Transforms/BufferizableOpInterfaceImpl.cpp
//   * `bufferize` now records the op location in `loc`.
//   * After both operand buffers are obtained via `state.getBuffer`, their
//     types are compared. If they differ, a strided layout map is built with a
//     dynamic offset and one dynamic stride per dimension (the rank is taken
//     from the tensor type of the select's true value). A buffer type is then
//     requested from `bufferization::getMemRefType` with that layout and the
//     memory space of the "true" buffer, and both buffers are replaced by ops
//     created with that result type before the bufferized select is built.
//   * NOTE(review): only the "true" buffer's memory space is consulted;
//     presumably both operands always share one memory space - TODO confirm.
//   * NOTE(review): both buffers are cast whenever the types differ, even if
//     one of them already has the target type.
//
// mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize.mlir
//   * New case `@select_different_tensors` selects between a
//     `linalg.init_tensor` result and a function argument. The CHECK lines
//     expect a `memref.cast` of the allocation and an `arith.select` that uses
//     the casted value together with the `bufferization.to_memref` result.
//
// NOTE(review): this copy of the patch looks damaged in transit. Line breaks
// are collapsed, and template / type arguments appear to be missing (for
// example `cast(op)`, `SmallVector dynamicStrides`, `rewriter.create(...)`,
// and bare `tensor` / `memref` in the test). Presumably the angle-bracket
// contents were stripped - verify against the original patch before applying.
diff --git a/mlir/lib/Dialect/Arithmetic/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Arithmetic/Transforms/BufferizableOpInterfaceImpl.cpp --- a/mlir/lib/Dialect/Arithmetic/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Arithmetic/Transforms/BufferizableOpInterfaceImpl.cpp @@ -129,6 +129,7 @@ LogicalResult bufferize(Operation *op, RewriterBase &rewriter, BufferizationState &state) const { auto selectOp = cast(op); + Location loc = selectOp.getLoc(); // `getBuffer` introduces copies if an OpOperand bufferizes out-of-place. // TODO: It would be more efficient to copy the result of the `select` op @@ -139,6 +140,26 @@ *state.getBuffer(rewriter, selectOp->getOpOperand(1) /*true_value*/); Value falseBuffer = *state.getBuffer(rewriter, selectOp->getOpOperand(2) /*false_value*/); + + // The "true" and the "false" operands must have the same type. If the + // buffers have different types, they differ only in their layout map. Cast + // both of them to the most dynamic MemRef type. 
+ if (trueBuffer.getType() != falseBuffer.getType()) { + auto trueType = trueBuffer.getType().cast(); + auto tensorType = selectOp.getTrueValue().getType().cast(); + int64_t dynamicOffset = ShapedType::kDynamicStrideOrOffset; + SmallVector dynamicStrides(tensorType.getRank(), + ShapedType::kDynamicStrideOrOffset); + AffineMap stridedLayout = makeStridedLinearLayoutMap( + dynamicStrides, dynamicOffset, op->getContext()); + BaseMemRefType castedType = bufferization::getMemRefType( + tensorType, state.getOptions(), AffineMapAttr::get(stridedLayout), + trueType.getMemorySpace()); + trueBuffer = rewriter.create(loc, castedType, trueBuffer); + falseBuffer = + rewriter.create(loc, castedType, falseBuffer); + } + replaceOpWithNewBufferizedOp( rewriter, op, selectOp.getCondition(), trueBuffer, falseBuffer); return success(); diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize.mlir --- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize.mlir +++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize.mlir @@ -105,4 +105,18 @@ return %0 : tensor<10xf32> } +// ----- +// CHECK-LABEL: func @select_different_tensors( +// CHECK-SAME: %[[t:.*]]: tensor +func @select_different_tensors(%t: tensor, %sz: index, %c: i1) -> tensor { + // CHECK-DAG: %[[m:.*]] = bufferization.to_memref %[[t]] : memref + // CHECK-DAG: %[[alloc:.*]] = memref.alloc(%{{.*}}) {{.*}} : memref + %0 = linalg.init_tensor [%sz] : tensor + + // A cast must be inserted because %t and %0 have different memref types. + // CHECK: %[[casted:.*]] = memref.cast %[[alloc]] : memref to memref + // CHECK: arith.select %{{.*}}, %[[casted]], %[[m]] + %1 = arith.select %c, %0, %t : tensor + return %1 : tensor +}