Index: mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRV.cpp
===================================================================
--- mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRV.cpp
+++ mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRV.cpp
@@ -50,6 +50,16 @@
   LogicalResult
   matchAndRewrite(scf::YieldOp terminatorOp, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const override {
+    // Update the branch to the merge block to pass the arguments that will
+    // escape the region.
+    auto br =
+        cast<spirv::BranchOp>(rewriter.getInsertionBlock()->getTerminator());
+    SmallVector<Value, 8> args(br.getBlockArguments());
+    args.append(operands.begin(), operands.end());
+    rewriter.setInsertionPoint(br);
+    rewriter.create<spirv::BranchOp>(terminatorOp.getLoc(), br.getTarget(),
+                                     args);
+    rewriter.eraseOp(br);
     rewriter.eraseOp(terminatorOp);
     return success();
   }
@@ -128,6 +138,38 @@
 } // namespace
 
+// Helper function to handle the results of scf operations.
+// We store the values yielded by the scf op in function-level variables, as
+// spv.loop/spv.selection currently do not support results. Right after the
+// SPIR-V op we load the values back and use them as the scf op's results.
+template <typename ScfOp, typename OpTy>
+static void ReplaceSCFOutputValue(ScfOp scfOp, OpTy newOp,
+                                  SPIRVTypeConverter &typeConverter,
+                                  ConversionPatternRewriter &rewriter,
+                                  Block *mergeBlock) {
+
+  auto loc = scfOp.getLoc();
+  SmallVector<Value, 8> resultValue;
+  for (auto result : scfOp.results()) {
+    auto convertedType = typeConverter.convertType(result.getType());
+    auto pointerType =
+        spirv::PointerType::get(convertedType, spirv::StorageClass::Function);
+    rewriter.setInsertionPoint(newOp);
+    auto alloc = rewriter.create<spirv::VariableOp>(
+        loc, pointerType,
+        rewriter.getI32IntegerAttr(
+            static_cast<uint32_t>(spirv::StorageClass::Function)),
+        nullptr);
+    rewriter.setInsertionPointAfter(newOp);
+    Value loadResult = rewriter.create<spirv::LoadOp>(loc, alloc);
+    resultValue.push_back(loadResult);
+    rewriter.setInsertionPoint(mergeBlock, mergeBlock->begin());
+    auto arg = mergeBlock->addArgument(convertedType);
+    rewriter.create<spirv::StoreOp>(loc, alloc, arg);
+  }
+  rewriter.replaceOp(scfOp, resultValue);
+}
+
 //===----------------------------------------------------------------------===//
 // scf::ForOp.
 //===----------------------------------------------------------------------===//
@@ -156,6 +198,8 @@
     // Create the new induction variable to use.
     BlockArgument newIndVar =
         header->addArgument(forOperands.lowerBound().getType());
+    for (auto arg : forOperands.initArgs())
+      header->addArgument(arg.getType());
     Block *body = forOp.getBody();
 
     // Apply signature conversion to the body of the forOp. It has a single block,
@@ -164,28 +208,34 @@
     TypeConverter::SignatureConversion signatureConverter(
         body->getNumArguments());
     signatureConverter.remapInput(0, newIndVar);
+    for (unsigned int i = 1; i < body->getNumArguments(); i++)
+      signatureConverter.remapInput(i, header->getArgument(i));
     body = rewriter.applySignatureConversion(&forOp.getLoopBody(),
                                              signatureConverter);
 
-    // Delete the loop terminator.
-    rewriter.eraseOp(body->getTerminator());
-
     // Move the blocks from the forOp into the loopOp. This is the body of the
     // loopOp.
     rewriter.inlineRegionBefore(forOp.getOperation()->getRegion(0), loopOp.body(),
                                 std::next(loopOp.body().begin(), 2));
 
+    SmallVector<Value, 8> args(1, forOperands.lowerBound());
+    args.append(forOperands.initArgs().begin(), forOperands.initArgs().end());
     // Branch into it from the entry.
     rewriter.setInsertionPointToEnd(&(loopOp.body().front()));
-    rewriter.create<spirv::BranchOp>(loc, header, forOperands.lowerBound());
+    rewriter.create<spirv::BranchOp>(loc, header, args);
 
     // Generate the rest of the loop header.
     rewriter.setInsertionPointToEnd(header);
     auto mergeBlock = loopOp.getMergeBlock();
     auto cmpOp = rewriter.create<spirv::SLessThanOp>(
         loc, rewriter.getI1Type(), newIndVar, forOperands.upperBound());
+    args.clear();
+    // Forward all the arguments to the merge block except the induction variable.
+    for (unsigned i = 1; i < header->getNumArguments(); i++)
+      args.push_back(header->getArgument(i));
+
     rewriter.create<spirv::BranchConditionalOp>(
-        loc, cmpOp, body, ArrayRef<Value>(), mergeBlock, ArrayRef<Value>());
+        loc, cmpOp, body, ArrayRef<Value>(), mergeBlock, args);
 
     // Generate instructions to increment the step of the induction variable and
     // branch to the header.
@@ -197,7 +247,7 @@
         loc, newIndVar.getType(), newIndVar, forOperands.step());
     rewriter.create<spirv::BranchOp>(loc, header, updatedIndVar);
 
-    rewriter.eraseOp(forOp);
+    ReplaceSCFOutputValue(forOp, loopOp, typeConverter, rewriter, mergeBlock);
     return success();
   }
 
@@ -249,7 +299,7 @@
         thenBlock, ArrayRef<Value>(), elseBlock, ArrayRef<Value>());
 
-    rewriter.eraseOp(ifOp);
+    ReplaceSCFOutputValue(ifOp, selectionOp, typeConverter, rewriter, mergeBlock);
     return success();
   }
 
Index: mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
===================================================================
--- mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
+++ mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
@@ -475,10 +475,9 @@
   return getElementType(type, indices, errorFn);
 }
 
-/// Returns true if the given `block` only contains one `spv._merge` op.
+/// Returns true if the given `block` ends with an `spv._merge` op.
 static inline bool isMergeBlock(Block &block) {
-  return !block.empty() && std::next(block.begin()) == block.end() &&
-         isa<spirv::MergeOp>(block.front());
+  return !block.empty() && isa<spirv::MergeOp>(block.back());
 }
 
 //===----------------------------------------------------------------------===//
Index: mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp
===================================================================
--- mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp
+++ mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp
@@ -589,9 +589,6 @@
 
 void PointerType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
                                 Optional<StorageClass> storage) {
-  if (storage)
-    assert(*storage == getStorageClass() && "inconsistent storage class!");
-
   // Use this pointer type's storage class because this pointer indicates we are
   // using the pointee type in that specific storage class.
   getPointeeType().cast<SPIRVType>().getExtensions(extensions,
@@ -604,9 +601,6 @@
 
 void PointerType::getCapabilities(
     SPIRVType::CapabilityArrayRefVector &capabilities,
     Optional<StorageClass> storage) {
-  if (storage)
-    assert(*storage == getStorageClass() && "inconsistent storage class!");
-
   // Use this pointer type's storage class because this pointer indicates we are
   // using the pointee type in that specific storage class.
   getPointeeType().cast<SPIRVType>().getCapabilities(capabilities,
Index: mlir/test/Conversion/GPUToSPIRV/if.mlir
===================================================================
--- mlir/test/Conversion/GPUToSPIRV/if.mlir
+++ mlir/test/Conversion/GPUToSPIRV/if.mlir
@@ -89,5 +89,76 @@
       }
       gpu.return
     }
+
+    // CHECK-LABEL: @simple_if_yield
+    gpu.func @simple_if_yield(%arg2 : memref<10xf32>, %arg3 : i1) kernel
+    attributes {spv.entry_point_abi = {local_size = dense<[16, 1, 1]>: vector<3xi32>}} {
+      // CHECK: %[[VAR1:.*]] = spv.Variable : !spv.ptr<f32, Function>
+      // CHECK: %[[VAR2:.*]] = spv.Variable : !spv.ptr<f32, Function>
+      // CHECK: spv.selection {
+      // CHECK-NEXT: spv.BranchConditional {{%.*}}, [[TRUE:\^.*]], [[FALSE:\^.*]]
+      // CHECK-NEXT: [[TRUE]]:
+      // CHECK: %[[RET1TRUE:.*]] = spv.constant 0.000000e+00 : f32
+      // CHECK: %[[RET2TRUE:.*]] = spv.constant 1.000000e+00 : f32
+      // CHECK: spv.Branch ^[[MERGE:.*]](%[[RET1TRUE]], %[[RET2TRUE]] : f32, f32)
+      // CHECK-NEXT: [[FALSE]]:
+      // CHECK: %[[RET2FALSE:.*]] = spv.constant 2.000000e+00 : f32
+      // CHECK: %[[RET1FALSE:.*]] = spv.constant 3.000000e+00 : f32
+      // CHECK: spv.Branch ^[[MERGE]](%[[RET1FALSE]], %[[RET2FALSE]] : f32, f32)
+      // CHECK-NEXT: ^[[MERGE]](%[[RET1:.*]]: f32, %[[RET2:.*]]: f32):
+      // CHECK-DAG: spv.Store "Function" %[[VAR1]], %[[RET1]] : f32
+      // CHECK-DAG: spv.Store "Function" %[[VAR2]], %[[RET2]] : f32
+      // CHECK: spv._merge
+      // CHECK-NEXT: }
+      // CHECK-DAG: %[[OUT1:.*]] = spv.Load "Function" %[[VAR1]] : f32
+      // CHECK-DAG: %[[OUT2:.*]] = spv.Load "Function" %[[VAR2]] : f32
+      // CHECK: spv.Store "StorageBuffer" {{%.*}}, %[[OUT1]] : f32
+      // CHECK: spv.Store "StorageBuffer" {{%.*}}, %[[OUT2]] : f32
+      // CHECK: spv.Return
+      %0:2 = scf.if %arg3 -> (f32, f32) {
+        %c0 = constant 0.0 : f32
+        %c1 = constant 1.0 : f32
+        scf.yield %c0, %c1 : f32, f32
+      } else {
+        %c0 = constant 2.0 : f32
+        %c1 = constant 3.0 : f32
+        scf.yield %c1, %c0 : f32, f32
+      }
+      %i = constant 0 : index
+      %j = constant 1 : index
+      store %0#0, %arg2[%i] : memref<10xf32>
+      store %0#1, %arg2[%j] : memref<10xf32>
+      gpu.return
+    }
+
+    // CHECK-LABEL: @simple_if_yield_type_change
+    // CHECK: %[[VAR:.*]] = spv.Variable : !spv.ptr<!spv.ptr<!spv.struct<!spv.array<10 x f32, stride=4> [0]>, StorageBuffer>, Function>
+    // CHECK: spv.selection {
+    // CHECK-NEXT: spv.BranchConditional {{%.*}}, [[TRUE:\^.*]], [[FALSE:\^.*]]
+    // CHECK-NEXT: [[TRUE]]:
+    // CHECK: spv.Branch ^[[MERGE:.*]]({{%.*}} : !spv.ptr<!spv.struct<!spv.array<10 x f32, stride=4> [0]>, StorageBuffer>)
+    // CHECK-NEXT: [[FALSE]]:
+    // CHECK: spv.Branch ^[[MERGE]]({{%.*}} : !spv.ptr<!spv.struct<!spv.array<10 x f32, stride=4> [0]>, StorageBuffer>)
+    // CHECK-NEXT: ^[[MERGE]](%[[RET:.*]]: !spv.ptr<!spv.struct<!spv.array<10 x f32, stride=4> [0]>, StorageBuffer>):
+    // CHECK: spv.Store "Function" %[[VAR]], %[[RET]] : !spv.ptr<!spv.struct<!spv.array<10 x f32, stride=4> [0]>, StorageBuffer>
+    // CHECK: spv._merge
+    // CHECK-NEXT: }
+    // CHECK: %[[OUT:.*]] = spv.Load "Function" %[[VAR]] : !spv.ptr<!spv.struct<!spv.array<10 x f32, stride=4> [0]>, StorageBuffer>
+    // CHECK: %[[ADD:.*]] = spv.AccessChain %[[OUT]][{{%.*}}, {{%.*}}] : !spv.ptr<!spv.struct<!spv.array<10 x f32, stride=4> [0]>, StorageBuffer>
+    // CHECK: spv.Store "StorageBuffer" %[[ADD]], {{%.*}} : f32
+    // CHECK: spv.Return
+    gpu.func @simple_if_yield_type_change(%arg2 : memref<10xf32>, %arg3 : memref<10xf32>, %arg4 : i1) kernel
+    attributes {spv.entry_point_abi = {local_size = dense<[16, 1, 1]>: vector<3xi32>}} {
+      %i = constant 0 : index
+      %value = constant 0.0 : f32
+
+      %0 = scf.if %arg4 -> (memref<10xf32>) {
+        scf.yield %arg2 : memref<10xf32>
+      } else {
+        scf.yield %arg3 : memref<10xf32>
+      }
+      store %value, %0[%i] : memref<10xf32>
+      gpu.return
+    }
   }
 }
Index: mlir/test/Conversion/GPUToSPIRV/loop.mlir
===================================================================
--- mlir/test/Conversion/GPUToSPIRV/loop.mlir
+++ mlir/test/Conversion/GPUToSPIRV/loop.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -convert-gpu-to-spirv %s -o - | FileCheck %s
+// RUN: mlir-opt -convert-gpu-to-spirv -split-input-file %s -o - | FileCheck %s
 
 module attributes {
   gpu.container_module,
@@ -51,5 +51,48 @@
       }
      gpu.return
    }
+
+
+    // CHECK-LABEL: @loop_yield
+    gpu.func @loop_yield(%arg2 : memref<10xf32>, %arg3 : memref<10xf32>) kernel
+    attributes {spv.entry_point_abi = {local_size = dense<[16, 1, 1]>: vector<3xi32>}} {
+      // CHECK: %[[LB:.*]] = spv.constant 4 : i32
+      %lb = constant 4 : index
+      // CHECK: %[[UB:.*]] = spv.constant 42 : i32
+      %ub = constant 42 : index
+      // CHECK: %[[STEP:.*]] = spv.constant 2 : i32
+      %step = constant 2 : index
+      // CHECK: %[[INITVAR1:.*]] = spv.constant 0.000000e+00 : f32
+      %s0 = constant 0.0 : f32
+      // CHECK: %[[INITVAR2:.*]] = spv.constant 1.000000e+00 : f32
+      %s1 = constant 1.0 : f32
+      // CHECK: %[[VAR1:.*]] = spv.Variable : !spv.ptr<f32, Function>
+      // CHECK: %[[VAR2:.*]] = spv.Variable : !spv.ptr<f32, Function>
+      // CHECK: spv.loop {
+      // CHECK: spv.Branch ^[[HEADER:.*]](%[[LB]], %[[INITVAR1]], %[[INITVAR2]] : i32, f32, f32)
+      // CHECK: ^[[HEADER]](%[[INDVAR:.*]]: i32, %[[CARRIED1:.*]]: f32, %[[CARRIED2:.*]]: f32):
+      // CHECK: %[[CMP:.*]] = spv.SLessThan %[[INDVAR]], %[[UB]] : i32
+      // CHECK: spv.BranchConditional %[[CMP]], ^[[BODY:.*]], ^[[MERGE:.*]](%[[CARRIED1]], %[[CARRIED2]] : f32, f32)
+      // CHECK: ^[[BODY]]:
+      // CHECK: %[[UPDATED:.*]] = spv.FAdd %[[CARRIED1]], %[[CARRIED1]] : f32
+      // CHECK: %[[INCREMENT:.*]] = spv.IAdd %[[INDVAR]], %[[STEP]] : i32
+      // CHECK: spv.Branch ^[[HEADER]](%[[INCREMENT]], %[[UPDATED]], %[[UPDATED]] : i32, f32, f32)
+      // CHECK: ^[[MERGE]](%[[OUT1:.*]]: f32, %[[OUT2:.*]]: f32)
+      // CHECK-DAG: spv.Store "Function" %[[VAR1]], %[[OUT1]] : f32
+      // CHECK-DAG: spv.Store "Function" %[[VAR2]], %[[OUT2]] : f32
+      // CHECK: spv._merge
+      // CHECK: }
+      %result:2 = scf.for %i0 = %lb to %ub step %step iter_args(%si = %s0, %sj = %s1) -> (f32, f32) {
+        %sn = addf %si, %si : f32
+        scf.yield %sn, %sn : f32, f32
+      }
+      // CHECK-DAG: %[[OUT1:.*]] = spv.Load "Function" %[[VAR1]] : f32
+      // CHECK-DAG: %[[OUT2:.*]] = spv.Load "Function" %[[VAR2]] : f32
+      // CHECK: spv.Store "StorageBuffer" {{%.*}}, %[[OUT1]] : f32
+      // CHECK: spv.Store "StorageBuffer" {{%.*}}, %[[OUT2]] : f32
+      store %result#0, %arg3[%lb] : memref<10xf32>
+      store %result#1, %arg3[%ub] : memref<10xf32>
+      gpu.return
+    }
   }
 }
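
For reference, the lowering shape this patch produces for an scf.if with results can be sketched as follows (illustrative only, not part of the diff; the SSA names, the i1 condition, and the f32 result type are made up for the example). The scf.for case is analogous, with the loop-carried values additionally forwarded through the header and merge block arguments, as exercised by the loop.mlir test above.

  // Source: an scf.if that yields one value.
  %r = scf.if %cond -> (f32) {
    scf.yield %a : f32
  } else {
    scf.yield %b : f32
  }

  // After conversion: the yielded value travels through a merge-block argument,
  // is stored into a function-level spv.Variable inside the merge block, and is
  // re-loaded right after the spv.selection, since spv.selection/spv.loop have
  // no results of their own.
  %var = spv.Variable : !spv.ptr<f32, Function>
  spv.selection {
    spv.BranchConditional %cond, ^then, ^else
  ^then:
    spv.Branch ^merge(%a : f32)
  ^else:
    spv.Branch ^merge(%b : f32)
  ^merge(%arg : f32):
    spv.Store "Function" %var, %arg : f32
    spv._merge
  }
  %r = spv.Load "Function" %var : f32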