diff --git a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
--- a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -27,16 +27,78 @@
 namespace scf {
 namespace {
 
-// bufferization.to_memref is not allowed to change the rank.
-static void ensureToMemrefOpIsValid(Value tensor, Type memrefType) {
-#ifndef NDEBUG
-  auto rankedTensorType = tensor.getType().dyn_cast<RankedTensorType>();
-  assert((!rankedTensorType || (memrefType.cast<BaseMemRefType>().getRank() ==
-                                rankedTensorType.getRank())) &&
-         "to_memref would be invalid: mismatching ranks");
-#endif
+/// Helper function for loop bufferization. Cast the given buffer to the given
+/// memref type.
+static Value castBuffer(OpBuilder &b, Value buffer, Type type) {
+  assert(type.isa<BaseMemRefType>() && "expected BaseMemRefType");
+  assert(buffer.getType().isa<BaseMemRefType>() && "expected BaseMemRefType");
+  // If the buffer already has the correct type, no cast is needed.
+  if (buffer.getType() == type)
+    return buffer;
+  // TODO: In case `type` has a layout map that is not the fully dynamic
+  // one, we may not be able to cast the buffer. In that case, the loop
+  // iter_arg's layout map must be changed (see uses of `castBuffer`).
+  assert(memref::CastOp::areCastCompatible(buffer.getType(), type) &&
+         "scf.while op bufferization: cast incompatible");
+  return b.create<memref::CastOp>(buffer.getLoc(), type, buffer).getResult();
 }
 
+/// Bufferization of scf.condition.
+struct ConditionOpInterface
+    : public BufferizableOpInterface::ExternalModel<ConditionOpInterface,
+                                                    scf::ConditionOp> {
+  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
+                              const AnalysisState &state) const {
+    return true;
+  }
+
+  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
+                               const AnalysisState &state) const {
+    return false;
+  }
+
+  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
+                                            const AnalysisState &state) const {
+    return {};
+  }
+
+  bool mustBufferizeInPlace(Operation *op, OpOperand &opOperand,
+                            const AnalysisState &state) const {
+    // Condition operands always bufferize inplace. Otherwise, an alloc + copy
+    // may be generated inside the block. We should not return/yield
+    // allocations when possible.
+    return true;
+  }
+
+  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
+                          const BufferizationOptions &options) const {
+    auto conditionOp = cast<scf::ConditionOp>(op);
+    auto whileOp = cast<scf::WhileOp>(conditionOp->getParentOp());
+
+    SmallVector<Value> newArgs;
+    for (const auto &it : llvm::enumerate(conditionOp.getArgs())) {
+      Value value = it.value();
+      if (value.getType().isa<TensorType>()) {
+        FailureOr<Value> maybeBuffer = getBuffer(rewriter, value, options);
+        if (failed(maybeBuffer))
+          return failure();
+        FailureOr<BaseMemRefType> resultType = bufferization::getBufferType(
+            whileOp.getAfterArguments()[it.index()], options);
+        if (failed(resultType))
+          return failure();
+        Value buffer = castBuffer(rewriter, *maybeBuffer, *resultType);
+        newArgs.push_back(buffer);
+      } else {
+        newArgs.push_back(value);
+      }
+    }
+
+    replaceOpWithNewBufferizedOp<scf::ConditionOp>(
+        rewriter, op, conditionOp.getCondition(), newArgs);
+    return success();
+  }
+};
+
 /// Bufferization of scf.execute_region. Can be analyzed, but bufferization not
 /// fully implemented at the moment.
 struct ExecuteRegionOpInterface
@@ -283,22 +345,6 @@
   return result;
 }
 
-/// Helper function for loop bufferization. Cast the given buffer to the given
-/// memref type.
-static Value castBuffer(OpBuilder &b, Value buffer, Type type) {
-  assert(type.isa<BaseMemRefType>() && "expected BaseMemRefType");
-  assert(buffer.getType().isa<BaseMemRefType>() && "expected BaseMemRefType");
-  // If the buffer already has the correct type, no cast is needed.
-  if (buffer.getType() == type)
-    return buffer;
-  // TODO: In case `type` has a layout map that is not the fully dynamic
-  // one, we may not be able to cast the buffer. In that case, the loop
-  // iter_arg's layout map must be changed (see uses of `castBuffer`).
-  assert(memref::CastOp::areCastCompatible(buffer.getType(), type) &&
-         "scf.while op bufferization: cast incompatible");
-  return b.create<memref::CastOp>(buffer.getLoc(), type, buffer).getResult();
-}
-
 /// Helper function for loop bufferization. Return the bufferized values of the
 /// given OpOperands. If an operand is not a tensor, return the original value.
 static FailureOr<SmallVector<Value>>
@@ -319,60 +365,10 @@
   return result;
 }
 
-/// Helper function for loop bufferization. Compute the buffer that should be
-/// yielded from a loop block (loop body or loop condition).
-static FailureOr<Value> getYieldedBuffer(RewriterBase &rewriter, Value tensor,
-                                         BaseMemRefType type,
-                                         const BufferizationOptions &options) {
-  assert(tensor.getType().isa<TensorType>() && "expected tensor");
-  ensureToMemrefOpIsValid(tensor, type);
-  FailureOr<Value> yieldedVal = getBuffer(rewriter, tensor, options);
-  if (failed(yieldedVal))
-    return failure();
-  return castBuffer(rewriter, *yieldedVal, type);
-}
-
-/// Helper function for loop bufferization. Given a range of values, apply
-/// `func` to those marked in `tensorIndices`. Otherwise, store the unmodified
-/// value in the result vector.
-static FailureOr<SmallVector<Value>>
-convertTensorValues(ValueRange values, const DenseSet<int64_t> &tensorIndices,
-                    llvm::function_ref<FailureOr<Value>(Value, int64_t)> func) {
-  SmallVector<Value> result;
-  for (const auto &it : llvm::enumerate(values)) {
-    size_t idx = it.index();
-    Value val = it.value();
-    if (tensorIndices.contains(idx)) {
-      FailureOr<Value> maybeVal = func(val, idx);
-      if (failed(maybeVal))
-        return failure();
-      result.push_back(*maybeVal);
-    } else {
-      result.push_back(val);
-    }
-  }
-  return result;
-}
-
-/// Helper function for loop bufferization. Given a list of pre-bufferization
-/// yielded values, compute the list of bufferized yielded values.
-FailureOr<SmallVector<Value>>
-getYieldedValues(RewriterBase &rewriter, ValueRange values,
-                 TypeRange bufferizedTypes,
-                 const DenseSet<int64_t> &tensorIndices,
-                 const BufferizationOptions &options) {
-  return convertTensorValues(
-      values, tensorIndices, [&](Value val, int64_t index) {
-        return getYieldedBuffer(rewriter, val,
-                                bufferizedTypes[index].cast<BaseMemRefType>(),
-                                options);
-      });
-}
-
 /// Helper function for loop bufferization. Given a list of bbArgs of the new
 /// (bufferized) loop op, wrap the bufferized tensor args (now memrefs) into
 /// ToTensorOps, so that the block body can be moved over to the new op.
-SmallVector<Value>
+static SmallVector<Value>
 getBbArgReplacements(RewriterBase &rewriter, Block::BlockArgListType bbArgs,
                      const DenseSet<int64_t> &tensorIndices) {
   SmallVector<Value> result;
@@ -390,6 +386,74 @@
   return result;
 }
 
+/// Compute the bufferized type of a loop iter_arg. This type must be equal to
+/// the bufferized type of the corresponding init_arg and the bufferized type
+/// of the corresponding yielded value.
+///
+/// This function uses bufferization::getBufferType to compute the bufferized
+/// type of the init_arg and of the yielded value. (The computation of the
+/// latter usually requires computing the bufferized type of the corresponding
+/// iter_arg; the implementation of getBufferType traces back the use-def chain
+/// of the given value and computes a buffer type along the way.) If both
+/// buffer types are equal, no casts are needed and the computed buffer type
+/// can be used directly. Otherwise, the buffer types can only differ in their
+/// layout map and a cast must be inserted.
+static FailureOr<BaseMemRefType> computeLoopRegionIterArgBufferType(
+    BlockArgument iterArg, Value initArg, Value yieldedValue,
+    const BufferizationOptions &options,
+    const DenseMap<Value, BaseMemRefType> &fixedTypes) {
+  // Determine the buffer type of the init_arg.
+  auto initArgBufferType =
+      bufferization::getBufferType(initArg, options, fixedTypes);
+  if (failed(initArgBufferType))
+    return failure();
+
+  // Fix the iter_arg type, so that recursive lookups return the buffer type
+  // of the init_arg. This is to avoid infinite loops when calculating the
+  // buffer type of the yielded value.
+  //
+  // Note: For more precise layout map computation, a fixpoint iteration could
+  // be done (i.e., re-computing the yielded buffer type until the bufferized
+  // iter_arg type no longer changes). This current implementation immediately
+  // switches to a fully dynamic layout map when a mismatch between bufferized
+  // init_arg type and bufferized yield value type is detected.
+  DenseMap<Value, BaseMemRefType> newFixedTypes(fixedTypes);
+  newFixedTypes[iterArg] = *initArgBufferType;
+
+  // Compute the buffer type of the yielded value.
+  BaseMemRefType yieldedValueBufferType;
+  if (yieldedValue.getType().isa<BaseMemRefType>()) {
+    // scf.yield was already bufferized.
+    yieldedValueBufferType = yieldedValue.getType().cast<BaseMemRefType>();
+  } else {
+    auto maybeBufferType =
+        bufferization::getBufferType(yieldedValue, options, newFixedTypes);
+    if (failed(maybeBufferType))
+      return failure();
+    yieldedValueBufferType = *maybeBufferType;
+  }
+
+  // If yielded type and init_arg type are the same, use that type directly.
+  if (*initArgBufferType == yieldedValueBufferType)
+    return yieldedValueBufferType;
+
+  // If there is a mismatch between the yielded buffer type and the iter_arg
+  // buffer type, the buffer type must be promoted to a fully dynamic layout
+  // map.
+  auto yieldedRanked = yieldedValueBufferType.cast<MemRefType>();
+#ifndef NDEBUG
+  auto iterRanked = initArgBufferType->cast<MemRefType>();
+  assert(llvm::equal(yieldedRanked.getShape(), iterRanked.getShape()) &&
+         "expected same shape");
+  assert(yieldedRanked.getMemorySpaceAsInt() ==
+             iterRanked.getMemorySpaceAsInt() &&
+         "expected same memory space");
+#endif // NDEBUG
+  return getMemRefTypeWithFullyDynamicLayout(
+      iterArg.getType().cast<RankedTensorType>(),
+      yieldedRanked.getMemorySpaceAsInt());
+}
+
 /// Bufferization of scf.for. Replace with a new scf.for that operates on
 /// memrefs.
 struct ForOpInterface
@@ -507,60 +571,14 @@
       resultNum = value.cast<OpResult>().getResultNumber();
     }
 
-    // Determine the buffer type of the init_arg.
-    Value initArg = forOp.getInitArgs()[resultNum];
-    auto initArgBufferType =
-        bufferization::getBufferType(initArg, options, fixedTypes);
-    if (failed(initArgBufferType))
-      return failure();
-
-    // Fix the iter_arg type, so that recursive lookups return the buffer type
-    // of the init_arg. This is to avoid infinite loops when calculating the
-    // buffer type of the yielded value.
-    //
-    // Note: For more precise layout map computation, a fixpoint iteration could
-    // be done (i.e., re-computing the yielded buffer type until the bufferized
-    // iter_arg type no longer changes). This current implementation immediately
-    // switches to a fully dynamic layout map when a mismatch between bufferized
-    // init_arg type and bufferized yield value type is detected.
-    DenseMap<Value, BaseMemRefType> newFixedTypes(fixedTypes);
-    newFixedTypes[forOp.getRegionIterArgs()[resultNum]] = *initArgBufferType;
-
-    // Compute the buffer type of the yielded value.
+    // Compute the bufferized type.
     auto yieldOp =
         cast<scf::YieldOp>(forOp.getLoopBody().front().getTerminator());
     Value yieldedValue = yieldOp.getOperand(resultNum);
-    BaseMemRefType yieldedValueBufferType;
-    if (yieldedValue.getType().isa<BaseMemRefType>()) {
-      // scf.yield was already bufferized.
-      yieldedValueBufferType = yieldedValue.getType().cast<BaseMemRefType>();
-    } else {
-      auto maybeBufferType =
-          bufferization::getBufferType(yieldedValue, options, newFixedTypes);
-      if (failed(maybeBufferType))
-        return failure();
-      yieldedValueBufferType = *maybeBufferType;
-    }
-
-    // If yielded type and init_arg type are the same, use that type directly.
-    if (*initArgBufferType == yieldedValueBufferType)
-      return yieldedValueBufferType;
-
-    // If there is a mismatch between the yielded buffer type and the iter_arg
-    // buffer type, the buffer type must be promoted to a fully dynamic layout
-    // map.
-    auto yieldedRanked = yieldedValueBufferType.cast<MemRefType>();
-#ifndef NDEBUG
-    auto iterRanked = initArgBufferType->cast<MemRefType>();
-    assert(llvm::equal(yieldedRanked.getShape(), iterRanked.getShape()) &&
-           "expected same shape");
-    assert(yieldedRanked.getMemorySpaceAsInt() ==
-               iterRanked.getMemorySpaceAsInt() &&
-           "expected same memory space");
-#endif // NDEBUG
-    return getMemRefTypeWithFullyDynamicLayout(
-        forOp.getRegionIterArgs()[resultNum].getType().cast<RankedTensorType>(),
-        yieldedRanked.getMemorySpaceAsInt());
+    BlockArgument iterArg = forOp.getRegionIterArgs()[resultNum];
+    Value initArg = forOp.getInitArgs()[resultNum];
+    return computeLoopRegionIterArgBufferType(iterArg, initArg, yieldedValue,
+                                              options, fixedTypes);
   }
 
   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                           const BufferizationOptions &options) const {
@@ -800,8 +818,6 @@
     return success();
   }
 
-  // TODO: Implement getBufferType interface method and infer buffer types.
-
   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                           const BufferizationOptions &options) const {
     auto whileOp = cast<scf::WhileOp>(op);
@@ -826,6 +842,17 @@
       return failure();
     SmallVector<Value> initArgs = *maybeInitArgs;
 
+    // Cast init_args if necessary.
+    SmallVector<Value> castedInitArgs;
+    for (const auto &it : llvm::enumerate(initArgs)) {
+      Value initArg = it.value();
+      auto targetType = bufferization::getBufferType(
+          whileOp.getBeforeArguments()[it.index()], options);
+      if (failed(targetType))
+        return failure();
+      castedInitArgs.push_back(castBuffer(rewriter, initArg, *targetType));
+    }
+
     // The result types of a WhileOp are the same as the "after" bbArg types.
     SmallVector<Type> argsTypesAfter = llvm::to_vector(
         llvm::map_range(whileOp.getAfterArguments(), [&](BlockArgument bbArg) {
@@ -834,13 +861,14 @@
         }));
 
     // Construct a new scf.while op with memref instead of tensor values.
-    ValueRange argsRangeBefore(initArgs);
+    ValueRange argsRangeBefore(castedInitArgs);
     TypeRange argsTypesBefore(argsRangeBefore);
 
-    auto newWhileOp = rewriter.create<scf::WhileOp>(whileOp.getLoc(),
-                                                    argsTypesAfter, initArgs);
+    auto newWhileOp = rewriter.create<scf::WhileOp>(
+        whileOp.getLoc(), argsTypesAfter, castedInitArgs);
 
     // Add before/after regions to the new op.
-    SmallVector<Location> bbArgLocsBefore(initArgs.size(), whileOp.getLoc());
+    SmallVector<Location> bbArgLocsBefore(castedInitArgs.size(),
+                                          whileOp.getLoc());
     SmallVector<Location> bbArgLocsAfter(argsTypesAfter.size(),
                                          whileOp.getLoc());
     Block *newBeforeBody = &newWhileOp.getBefore().emplaceBlock();
@@ -856,19 +884,6 @@
         rewriter, newWhileOp.getBeforeArguments(), indicesBefore);
     rewriter.mergeBlocks(beforeBody, newBeforeBody, newBeforeArgs);
 
-    // Update scf.condition of new loop.
-    auto newConditionOp = newWhileOp.getConditionOp();
-    rewriter.setInsertionPoint(newConditionOp);
-    // Only equivalent buffers or new buffer allocations may be yielded to the
-    // "after" region.
-    // TODO: This could be relaxed for better bufferization results.
-    FailureOr<SmallVector<Value>> newConditionArgs =
-        getYieldedValues(rewriter, newConditionOp.getArgs(), argsTypesAfter,
-                         indicesAfter, options);
-    if (failed(newConditionArgs))
-      return failure();
-    newConditionOp.getArgsMutable().assign(*newConditionArgs);
-
     // Set up new iter_args and move the loop body block to the new op.
     // The old block uses tensors, so wrap the (memref) bbArgs of the new block
     // in ToTensorOps.
@@ -877,25 +892,51 @@
         rewriter, newWhileOp.getAfterArguments(), indicesAfter);
     rewriter.mergeBlocks(afterBody, newAfterBody, newAfterArgs);
 
-    // Update scf.yield of the new loop.
-    auto newYieldOp = newWhileOp.getYieldOp();
-    rewriter.setInsertionPoint(newYieldOp);
-    // Only equivalent buffers or new buffer allocations may be yielded to the
-    // "before" region.
-    // TODO: This could be relaxed for better bufferization results.
-    FailureOr<SmallVector<Value>> newYieldValues =
-        getYieldedValues(rewriter, newYieldOp.getResults(), argsTypesBefore,
-                         indicesBefore, options);
-    if (failed(newYieldValues))
-      return failure();
-    newYieldOp.getResultsMutable().assign(*newYieldValues);
-
     // Replace loop results.
     replaceOpWithBufferizedValues(rewriter, op, newWhileOp->getResults());
 
     return success();
   }
 
+  FailureOr<BaseMemRefType>
+  getBufferType(Operation *op, Value value, const BufferizationOptions &options,
+                const DenseMap<Value, BaseMemRefType> &fixedTypes) const {
+    auto whileOp = cast<scf::WhileOp>(op);
+    assert(getOwnerOfValue(value) == op && "invalid value");
+    assert(value.getType().isa<TensorType>() && "expected tensor type");
+
+    // Case 1: Block argument of the "before" region.
+    if (auto bbArg = value.dyn_cast<BlockArgument>()) {
+      if (bbArg.getOwner()->getParent() == &whileOp.getBefore()) {
+        Value initArg = whileOp.getInits()[bbArg.getArgNumber()];
+        auto yieldOp = whileOp.getYieldOp();
+        Value yieldedValue = yieldOp.getOperand(bbArg.getArgNumber());
+        return computeLoopRegionIterArgBufferType(bbArg, initArg, yieldedValue,
+                                                  options, fixedTypes);
+      }
+    }
+
+    // Case 2: OpResult of the loop or block argument of the "after" region.
+    // The bufferized "after" bbArg type can be directly computed from the
+    // bufferized "before" bbArg type.
+    unsigned resultNum;
+    if (auto opResult = value.dyn_cast<OpResult>()) {
+      resultNum = opResult.getResultNumber();
+    } else if (value.cast<BlockArgument>().getOwner()->getParent() ==
+               &whileOp.getAfter()) {
+      resultNum = value.cast<BlockArgument>().getArgNumber();
+    } else {
+      llvm_unreachable("invalid value");
+    }
+    Value conditionYieldedVal = whileOp.getConditionOp().getArgs()[resultNum];
+    if (!conditionYieldedVal.getType().isa<TensorType>()) {
+      // scf.condition was already bufferized.
+      return conditionYieldedVal.getType().cast<BaseMemRefType>();
+    }
+    return bufferization::getBufferType(conditionYieldedVal, options,
+                                        fixedTypes);
+  }
+
   /// Assert that yielded values of an scf.while op are equivalent to their
   /// corresponding bbArgs. In that case, the buffer relations of the
   /// corresponding OpResults are "Equivalent".
@@ -979,11 +1020,6 @@
             yieldOp->getParentOp()))
       return yieldOp->emitError("unsupported scf::YieldOp parent");
 
-    // TODO: Bufferize scf.yield inside scf.while here. (Currently bufferized
-    // together with scf.while.)
-    if (isa<scf::WhileOp>(yieldOp->getParentOp()))
-      return success();
-
     SmallVector<Value> newResults;
     for (const auto &it : llvm::enumerate(yieldOp.getResults())) {
       Value value = it.value();
@@ -992,15 +1028,20 @@
         if (failed(maybeBuffer))
           return failure();
        Value buffer = *maybeBuffer;
-        // In case of scf::ForOp / scf::IfOp, we may have to cast the value
-        // before yielding it.
-        // TODO: Do the same for scf::WhileOp.
+        // We may have to cast the value before yielding it.
        if (isa<scf::ForOp, scf::IfOp>(yieldOp->getParentOp())) {
          FailureOr<BaseMemRefType> resultType = bufferization::getBufferType(
              yieldOp->getParentOp()->getResult(it.index()), options);
          if (failed(resultType))
            return failure();
          buffer = castBuffer(rewriter, buffer, *resultType);
+        } else if (auto whileOp =
+                       dyn_cast<scf::WhileOp>(yieldOp->getParentOp())) {
+          FailureOr<BaseMemRefType> resultType = bufferization::getBufferType(
+              whileOp.getBeforeArguments()[it.index()], options);
+          if (failed(resultType))
+            return failure();
+          buffer = castBuffer(rewriter, buffer, *resultType);
        }
        newResults.push_back(buffer);
      } else {
@@ -1103,6 +1144,7 @@
 void mlir::scf::registerBufferizableOpInterfaceExternalModels(
     DialectRegistry &registry) {
   registry.addExtension(+[](MLIRContext *ctx, scf::SCFDialect *dialect) {
+    ConditionOp::attachInterface<ConditionOpInterface>(*ctx);
     ExecuteRegionOp::attachInterface<ExecuteRegionOpInterface>(*ctx);
     ForOp::attachInterface<ForOpInterface>(*ctx);
     IfOp::attachInterface<IfOpInterface>(*ctx);
diff --git a/mlir/test/Dialect/SCF/one-shot-bufferize.mlir b/mlir/test/Dialect/SCF/one-shot-bufferize.mlir
--- a/mlir/test/Dialect/SCF/one-shot-bufferize.mlir
+++ b/mlir/test/Dialect/SCF/one-shot-bufferize.mlir
@@ -388,23 +388,23 @@
     // CHECK: %[[a0:.*]] = memref.alloc() {{.*}} : memref<5xi1>
     // CHECK: memref.copy %[[w0]], %[[a0]]
     // CHECK: memref.dealloc %[[w0]]
-    // CHECK: %[[casted1:.*]] = memref.cast %[[a1]]
-    // CHECK: %[[casted0:.*]] = memref.cast %[[a0]]
-    // CHECK: %[[cloned0:.*]] = bufferization.clone %[[casted0]]
-    // CHECK: memref.dealloc %[[a0]]
-    // CHECK: %[[cloned1:.*]] = bufferization.clone %[[casted1]]
+    // CHECK: %[[cloned1:.*]] = bufferization.clone %[[a1]]
     // CHECK: memref.dealloc %[[a1]]
+    // CHECK: %[[cloned0:.*]] = bufferization.clone %[[a0]]
+    // CHECK: memref.dealloc %[[a0]]
     // CHECK: scf.condition(%[[condition]]) %[[cloned1]], %[[cloned0]]
     %condition = tensor.extract %w0[%idx] : tensor<5xi1>
     scf.condition(%condition) %w1, %w0 : tensor<5xi1>, tensor<5xi1>
   } do {
   ^bb0(%b0: tensor<5xi1>, %b1: tensor<5xi1>):
     // CHECK: } do {
-    // CHECK: ^bb0(%[[b0:.*]]: memref<5xi1, #{{.*}}>, %[[b1:.*]]: memref<5xi1, #{{.*}}):
+    // CHECK: ^bb0(%[[b0:.*]]: memref<5xi1>, %[[b1:.*]]: memref<5xi1>):
     // CHECK: memref.store %{{.*}}, %[[b0]]
-    // CHECK: %[[cloned2:.*]] = bufferization.clone %[[b1]]
+    // CHECK: %[[casted0:.*]] = memref.cast %[[b0]] : memref<5xi1> to memref<5xi1, #{{.*}}>
+    // CHECK: %[[casted1:.*]] = memref.cast %[[b1]] : memref<5xi1> to memref<5xi1, #{{.*}}>
+    // CHECK: %[[cloned2:.*]] = bufferization.clone %[[casted1]]
     // CHECK: memref.dealloc %[[b1]]
-    // CHECK: %[[cloned3:.*]] = bufferization.clone %[[b0]]
+    // CHECK: %[[cloned3:.*]] = bufferization.clone %[[casted0]]
     // CHECK: memref.dealloc %[[b0]]
     // CHECK: scf.yield %[[cloned3]], %[[cloned2]]
     // CHECK: }
@@ -441,25 +441,24 @@
     // CHECK: %[[a0:.*]] = memref.alloc() {{.*}} : memref<5xi1>
     // CHECK: memref.copy %[[w0]], %[[a0]]
     // CHECK: memref.dealloc %[[w0]]
-    // CHECK: %[[casted1:.*]] = memref.cast %[[a1]]
-    // CHECK: %[[casted0:.*]] = memref.cast %[[a0]]
-    // CHECK: %[[cloned0:.*]] = bufferization.clone %[[casted0]]
-    // CHECK: memref.dealloc %[[a0]]
-    // CHECK: %[[cloned1:.*]] = bufferization.clone %[[casted1]]
+    // CHECK: %[[cloned1:.*]] = bufferization.clone %[[a1]]
     // CHECK: memref.dealloc %[[a1]]
+    // CHECK: %[[cloned0:.*]] = bufferization.clone %[[a0]]
+    // CHECK: memref.dealloc %[[a0]]
     // CHECK: scf.condition(%[[condition]]) %[[cloned1]], %[[cloned0]]
     %condition = tensor.extract %w0[%idx] : tensor<5xi1>
     scf.condition(%condition) %w1, %w0 : tensor<5xi1>, tensor<5xi1>
   } do {
   ^bb0(%b0: tensor<5xi1>, %b1: tensor<5xi1>):
     // CHECK: } do {
-    // CHECK: ^bb0(%[[b0:.*]]: memref<5xi1, #{{.*}}>, %[[b1:.*]]: memref<5xi1, #{{.*}}):
+    // CHECK: ^bb0(%[[b0:.*]]: memref<5xi1>, %[[b1:.*]]: memref<5xi1>):
     // CHECK: memref.store %{{.*}}, %[[b0]]
     // CHECK: %[[a3:.*]] = memref.alloc() {{.*}} : memref<5xi1>
     // CHECK: memref.copy %[[b1]], %[[a3]]
     // CHECK: memref.dealloc %[[b1]]
     // CHECK: %[[a2:.*]] = memref.alloc() {{.*}} : memref<5xi1>
     // CHECK: memref.copy %[[b0]], %[[a2]]
+    // CHECK: memref.dealloc %[[b0]]
     // CHECK: %[[casted3:.*]] = memref.cast %[[a3]]
     // CHECK: %[[casted2:.*]] = memref.cast %[[a2]]
     // CHECK: %[[cloned2:.*]] = bufferization.clone %[[casted2]]
@@ -764,3 +763,32 @@
   %x = tensor.extract %r[%c1] : tensor<?xf32>
   return %x : f32
 }
+
+// -----
+
+// We just check that this example bufferizes to valid IR.
+
+// CHECK-LABEL: func @scf_while_buffer_type_mismatch
+func.func @scf_while_buffer_type_mismatch(%sz: index, %sz2: index) -> f32 {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c10 = arith.constant 10 : index
+  %cst = arith.constant 5.5 : f32
+  %0 = bufferization.alloc_tensor(%sz) : tensor<?xf32>
+  %e2 = tensor.extract_slice %0[1][%sz2][1] : tensor<?xf32> to tensor<?xf32>
+  // init_arg and iter_arg have different buffer types. This must be resolved
+  // with casts.
+  %r = scf.while (%t = %e2) : (tensor<?xf32>) -> (tensor<?xf32>) {
+    %c = "test.condition"() : () -> (i1)
+    %s = "test.dummy"() : () -> (index)
+    %e = tensor.extract_slice %t[1][%s][1] : tensor<?xf32> to tensor<?xf32>
+    scf.condition(%c) %e : tensor<?xf32>
+  } do {
+  ^bb0(%b0: tensor<?xf32>):
+    %s2 = "test.dummy"() : () -> (index)
+    %n = tensor.insert %cst into %b0[%s2] : tensor<?xf32>
+    scf.yield %n : tensor<?xf32>
+  }
+  %x = tensor.extract %r[%c1] : tensor<?xf32>
+  return %x : f32
+}
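
Note on the cast that castBuffer inserts: when the bufferized init_arg (or yielded value) has a specific layout while the loop iter_arg is given the fully dynamic layout, the two are reconciled with a memref.cast. The following minimal hand-written MLIR sketch illustrates that kind of cast; it uses the current `strided` layout syntax, and the function name and exact types are illustrative rather than the literal output of One-Shot Bufferize on the test above.

// A subview with a static offset has a concrete strided layout; casting it to
// the fully dynamic layout is the reconciliation step the patch performs for
// loop iter_args.
func.func @cast_to_dynamic_layout(%buf: memref<?xf32>, %sz: index)
    -> memref<?xf32, strided<[?], offset: ?>> {
  %slice = memref.subview %buf[1] [%sz] [1]
      : memref<?xf32> to memref<?xf32, strided<[1], offset: 1>>
  %casted = memref.cast %slice
      : memref<?xf32, strided<[1], offset: 1>> to memref<?xf32, strided<[?], offset: ?>>
  return %casted : memref<?xf32, strided<[?], offset: ?>>
}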