diff --git a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
--- a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -75,13 +75,23 @@
   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                           const BufferizationOptions &options) const {
     auto executeRegionOp = cast<scf::ExecuteRegionOp>(op);
+    auto oldYieldOp =
+        cast<scf::YieldOp>(executeRegionOp.getRegion().front().getTerminator());
 
     // Compute new result types.
     SmallVector<Type> newResultTypes;
-    for (Type type : executeRegionOp->getResultTypes()) {
+    for (const auto &it : llvm::enumerate(executeRegionOp->getResultTypes())) {
+      Type type = it.value();
       if (auto tensorType = type.dyn_cast<TensorType>()) {
         // TODO: Infer the result type instead of computing it.
-        newResultTypes.push_back(getMemRefType(tensorType, options));
+        FailureOr<unsigned> memorySpace = inferOrDefaultMemorySpace(
+            oldYieldOp.getResults()[it.index()], options);
+        if (failed(memorySpace))
+          return oldYieldOp->emitError(
+                     "could not infer memory space for operand #")
+                 << it.index();
+        newResultTypes.push_back(
+            getMemRefType(tensorType, options, /*layout=*/None, *memorySpace));
       } else {
         newResultTypes.push_back(type);
       }
@@ -182,18 +192,44 @@
     return true;
   }
 
+  FailureOr<unsigned> inferMemorySpace(Operation *op, Value value,
+                                       const AnalysisState &state) const {
+    const BufferizationOptions &options = state.getOptions();
+    auto ifOp = cast<scf::IfOp>(op);
+    OpResult opResult = value.cast<OpResult>();
+    assert(opResult.getOwner() == op && "value must be an OpResult of op");
+
+    auto thenYieldOp = cast<scf::YieldOp>(ifOp.thenBlock()->getTerminator());
+    FailureOr<unsigned> thenMemorySpace = inferOrDefaultMemorySpace(
+        thenYieldOp.getResults()[opResult.getResultNumber()], options);
+    auto elseYieldOp = cast<scf::YieldOp>(ifOp.elseBlock()->getTerminator());
+    FailureOr<unsigned> elseMemorySpace = inferOrDefaultMemorySpace(
+        elseYieldOp.getResults()[opResult.getResultNumber()], options);
+
+    if (failed(thenMemorySpace) || failed(elseMemorySpace) ||
+        *thenMemorySpace != *elseMemorySpace)
+      return failure();
+
+    return *thenMemorySpace;
+  }
+
   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                           const BufferizationOptions &options) const {
+    AnalysisState state(options);
     auto ifOp = cast<scf::IfOp>(op);
-
     // Compute new types of the bufferized scf.if op.
     SmallVector<Type> newTypes;
-    for (Type returnType : ifOp->getResultTypes()) {
-      if (auto tensorType = returnType.dyn_cast<TensorType>()) {
+    for (Value result : ifOp.getResults()) {
+      if (auto tensorType = result.getType().dyn_cast<TensorType>()) {
         // TODO: Infer the result type instead of computing it.
-        newTypes.push_back(getMemRefType(tensorType, options));
+        FailureOr<unsigned> memorySpace = inferMemorySpace(op, result, state);
+        if (failed(memorySpace))
+          return ifOp->emitError("could not infer memory space for result #")
+                 << result.cast<OpResult>().getResultNumber();
+        newTypes.push_back(
+            getMemRefType(tensorType, options, /*layout=*/None, *memorySpace));
       } else {
-        newTypes.push_back(returnType);
+        newTypes.push_back(result.getType());
       }
     }
 
@@ -496,6 +532,17 @@
     return success();
   }
 
+  FailureOr<unsigned> inferMemorySpace(Operation *op, Value value,
+                                       const AnalysisState &state) const {
+    if (OpResult opResult = value.dyn_cast<OpResult>())
+      return bufferization::inferMemorySpace(opResult, state);
+
+    auto forOp = cast<scf::ForOp>(op);
+    auto bbArg = value.cast<BlockArgument>();
+    return bufferization::inferMemorySpace(
+        forOp.getIterOperands()[bbArg.getArgNumber()], state.getOptions());
+  }
+
   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                           const BufferizationOptions &options) const {
     auto forOp = cast<scf::ForOp>(op);
@@ -1034,8 +1081,17 @@
     WalkResult walkResult =
         performConcurrentlyOp.walk([&](ParallelInsertSliceOp insertOp) {
           Location loc = insertOp.getLoc();
+          // TODO: Infer buffer type instead of computing it.
+          FailureOr<unsigned> memorySpace =
+              inferOrDefaultMemorySpace(insertOp.getSource(), options);
+          if (failed(memorySpace)) {
+            insertOp->emitError(
+                "could not infer memory space for source operand");
+            return WalkResult::interrupt();
+          }
           Type srcType = getMemRefType(
-              insertOp.getSource().getType().cast<RankedTensorType>(), options);
+              insertOp.getSource().getType().cast<RankedTensorType>(), options,
+              /*layout=*/None, *memorySpace);
           // ParallelInsertSliceOp bufferizes to a copy.
           auto srcMemref = b.create<bufferization::ToMemrefOp>(
               loc, srcType, insertOp.getSource());
diff --git a/mlir/test/Dialect/SCF/one-shot-bufferize-invalid.mlir b/mlir/test/Dialect/SCF/one-shot-bufferize-invalid.mlir
new file
--- /dev/null
+++ b/mlir/test/Dialect/SCF/one-shot-bufferize-invalid.mlir
@@ -0,0 +1,16 @@
+// RUN: mlir-opt %s -one-shot-bufferize -split-input-file -verify-diagnostics
+
+func.func @inconsistent_memory_space_scf_if(%c: i1) -> tensor<10xf32> {
+  // Yielding tensors with different memory spaces. Such IR cannot be
+  // bufferized.
+  %0 = bufferization.alloc_tensor() {bufferization.memory_space = [0]} : tensor<10xf32>
+  %1 = bufferization.alloc_tensor() {bufferization.memory_space = [1]} : tensor<10xf32>
+  // expected-error @+2 {{could not infer memory space for result #0}}
+  // expected-error @+1 {{failed to bufferize op}}
+  %r = scf.if %c -> tensor<10xf32> {
+    scf.yield %0 : tensor<10xf32>
+  } else {
+    scf.yield %1 : tensor<10xf32>
+  }
+  func.return %r : tensor<10xf32>
+}
diff --git a/mlir/test/Dialect/SCF/one-shot-bufferize.mlir b/mlir/test/Dialect/SCF/one-shot-bufferize.mlir
--- a/mlir/test/Dialect/SCF/one-shot-bufferize.mlir
+++ b/mlir/test/Dialect/SCF/one-shot-bufferize.mlir
@@ -633,3 +633,83 @@
   }
   return %0 : tensor<8x8xf32>
 }
+
+// -----
+
+// CHECK-LABEL: func @scf_if_memory_space
+func.func @scf_if_memory_space(%c: i1, %f: f32) -> (f32, f32)
+{
+  %c0 = arith.constant 0 : index
+  // CHECK: %[[alloc:.*]] = memref.alloc() {{.*}} : memref<5xf32, 1>
+  %0 = bufferization.alloc_tensor() {bufferization.memory_space = [1]} : tensor<5xf32>
+  // CHECK: scf.if %{{.*}} -> (memref<5xf32, #{{.*}}, 1>) {
+  %1 = scf.if %c -> tensor<5xf32> {
+    // TODO: These casts could be avoided.
+    // CHECK: %[[casted:.*]] = memref.cast %[[alloc]] : memref<5xf32, 1> to memref<5xf32, #{{.*}}, 1>
+    // CHECK: %[[cloned:.*]] = bufferization.clone %[[casted]]
+    // CHECK: scf.yield %[[cloned]]
+    scf.yield %0 : tensor<5xf32>
+  } else {
+    // CHECK: %[[alloc2:.*]] = memref.alloc() {{.*}} : memref<5xf32, 1>
+    // CHECK: memref.store %{{.*}}, %[[alloc2]]
+    // CHECK: %[[casted2:.*]] = memref.cast %[[alloc2]] : memref<5xf32, 1> to memref<5xf32, #{{.*}}, 1>
+    // CHECK: %[[cloned2:.*]] = bufferization.clone %[[casted2]]
+    // CHECK: memref.dealloc %[[alloc2]]
+    // CHECK: scf.yield %[[cloned2]]
+    %2 = tensor.insert %f into %0[%c0] : tensor<5xf32>
+    scf.yield %2 : tensor<5xf32>
+  }
+  %r0 = tensor.extract %0[%c0] : tensor<5xf32>
+  %r1 = tensor.extract %1[%c0] : tensor<5xf32>
+  return %r0, %r1 : f32, f32
+}
+
+// -----
+
+// CHECK-LABEL: func @scf_execute_region_memory_space
+// CHECK: memref.alloc() {{.*}} : memref<5xf32, 1>
+// CHECK: memref.store
+// CHECK: memref.load
+// CHECK: memref.dealloc
+func.func @scf_execute_region_memory_space(%f: f32) -> f32 {
+  %c0 = arith.constant 0 : index
+  %0 = scf.execute_region -> tensor<5xf32> {
+    %1 = bufferization.alloc_tensor() {bufferization.memory_space = [1]} : tensor<5xf32>
+    %2 = tensor.insert %f into %1[%c0] : tensor<5xf32>
+    scf.yield %2 : tensor<5xf32>
+  }
+  %r = tensor.extract %0[%c0] : tensor<5xf32>
+  return %r : f32
+}
+
+// -----
+
+// Additional allocs are inserted in the loop body. We just check that all
+// allocs have the correct memory space.
+
+// CHECK-LABEL: func @scf_for_swapping_yields_memory_space
+func.func @scf_for_swapping_yields_memory_space(
+    %sz: index, %C : tensor<4xf32>, %lb : index, %ub : index, %step : index)
+  -> (f32, f32)
+{
+  // CHECK: memref.alloc(%{{.*}}) {{.*}} : memref<?xf32, 1>
+  // CHECK: memref.alloc(%{{.*}}) {{.*}} : memref<?xf32, 1>
+  %A = bufferization.alloc_tensor(%sz) {bufferization.memory_space = [1]} : tensor<?xf32>
+  %B = bufferization.alloc_tensor(%sz) {bufferization.memory_space = [1]} : tensor<?xf32>
+
+  // CHECK: scf.for {{.*}} {
+  %r0:2 = scf.for %i = %lb to %ub step %step iter_args(%tA = %A, %tB = %B)
+      -> (tensor<?xf32>, tensor<?xf32>)
+  {
+    // CHECK: memref.alloc(%{{.*}}) {{.*}} : memref<?xf32, 1>
+    // CHECK: memref.alloc(%{{.*}}) {{.*}} : memref<?xf32, 1>
+    %ttA = tensor.insert_slice %C into %tA[0][4][1] : tensor<4xf32> into tensor<?xf32>
+    %ttB = tensor.insert_slice %C into %tB[0][4][1] : tensor<4xf32> into tensor<?xf32>
+    // Yield tensors in different order.
+    scf.yield %ttB, %ttA : tensor<?xf32>, tensor<?xf32>
+  }
+  // CHECK: }
+  %f0 = tensor.extract %r0#0[%step] : tensor<?xf32>
+  %f1 = tensor.extract %r0#1[%step] : tensor<?xf32>
+  return %f0, %f1: f32, f32
+}
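The sketch below is not part of the patch. It only illustrates how the external model of another region-carrying op could implement the same inferMemorySpace hook, following the pattern the patch uses for scf.execute_region and scf.if: look through the region's terminator and query the value that is yielded for the result. MyOp, MyYieldOp, MyOpBufferizationInterface, and the getBody() accessor are hypothetical names, and the FailureOr<unsigned> return type is assumed to match how the hunks above use it.

struct MyOpBufferizationInterface
    : public BufferizableOpInterface::ExternalModel<MyOpBufferizationInterface,
                                                    MyOp> {
  FailureOr<unsigned> inferMemorySpace(Operation *op, Value value,
                                       const AnalysisState &state) const {
    // The memory space of result #i is the memory space of the i-th yielded
    // value (or the default taken from the options if nothing can be
    // inferred, which is what inferOrDefaultMemorySpace provides).
    auto myOp = cast<MyOp>(op);
    OpResult opResult = value.cast<OpResult>();
    auto yieldOp = cast<MyYieldOp>(myOp.getBody()->getTerminator());
    return inferOrDefaultMemorySpace(
        yieldOp.getOperands()[opResult.getResultNumber()], state.getOptions());
  }
};

As in the scf.if model above, returning failure() signals that the memory space cannot be determined; the bufferize implementation is then expected to emit a diagnostic and abort bufferization of that op.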