Index: mlir/include/mlir/Conversion/SCFToSPIRV/SCFToSPIRV.h
===================================================================
--- mlir/include/mlir/Conversion/SCFToSPIRV/SCFToSPIRV.h
+++ mlir/include/mlir/Conversion/SCFToSPIRV/SCFToSPIRV.h
@@ -21,11 +21,23 @@
 // Owning list of rewriting patterns.
 class OwningRewritePatternList;
 class SPIRVTypeConverter;
+struct ScfToSPIRVContextImpl;
+
+struct ScfToSPIRVContext {
+  ScfToSPIRVContext();
+  ~ScfToSPIRVContext();
+
+  ScfToSPIRVContextImpl *getImpl() { return impl.get(); }
+
+private:
+  std::unique_ptr<ScfToSPIRVContextImpl> impl;
+};
 
 /// Collects a set of patterns to lower from scf.for, scf.if, and
 /// loop.terminator to CFG operations within the SPIR-V dialect.
 void populateSCFToSPIRVPatterns(MLIRContext *context,
                                 SPIRVTypeConverter &typeConverter,
+                                ScfToSPIRVContext &scfToSPIRVContext,
                                 OwningRewritePatternList &patterns);
 
 } // namespace mlir
Index: mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRVPass.cpp
===================================================================
--- mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRVPass.cpp
+++ mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRVPass.cpp
@@ -58,9 +58,10 @@
       spirv::SPIRVConversionTarget::get(targetAttr);
 
   SPIRVTypeConverter typeConverter(targetAttr);
+  ScfToSPIRVContext scfContext;
   OwningRewritePatternList patterns;
   populateGPUToSPIRVPatterns(context, typeConverter, patterns);
-  populateSCFToSPIRVPatterns(context, typeConverter, patterns);
+  populateSCFToSPIRVPatterns(context, typeConverter, scfContext, patterns);
   populateStandardToSPIRVPatterns(context, typeConverter, patterns);
 
   if (failed(applyFullConversion(kernelModules, *target, patterns)))
Index: mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp
===================================================================
--- mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp
+++ mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp
@@ -18,12 +18,44 @@
 
 using namespace mlir;
 
+namespace mlir {
+struct ScfToSPIRVContextImpl {
+  // Map between the spirv region control flow operation (spv.loop or
+  // spv.selection) to the VariableOp created to store the region results. The
+  // order of the VariableOp matches the order of the results.
+  DenseMap<Operation *, SmallVector<spirv::VariableOp, 8>> outputVars;
+};
+} // namespace mlir
+
+/// We use ScfToSPIRVContext to store information about the lowering of the scf
+/// region that needs to be used later on. When we lower scf.for/scf.if we create
+/// VariableOp to store the results. We need to keep track of the VariableOp
+/// created as we need to insert stores into them when lowering Yield. Those
+/// StoreOp cannot be created earlier as they may use a different type than
+/// yield operands.
+ScfToSPIRVContext::ScfToSPIRVContext() {
+  impl = std::make_unique<ScfToSPIRVContextImpl>();
+}
+ScfToSPIRVContext::~ScfToSPIRVContext() = default;
+
 namespace {
+/// Common class for all SCF to SPIR-V lowering patterns.
+template <typename OpTy>
+class SCFToSPIRVPattern : public SPIRVOpLowering<OpTy> {
+public:
+  SCFToSPIRVPattern(MLIRContext *context, SPIRVTypeConverter &converter,
+                    ScfToSPIRVContextImpl *scfToSPIRVContext)
+      : SPIRVOpLowering<OpTy>::SPIRVOpLowering(context, converter),
+        scfToSPIRVContext(scfToSPIRVContext) {}
+
+protected:
+  ScfToSPIRVContextImpl *scfToSPIRVContext;
+};
 /// Pattern to convert a scf::ForOp within kernel functions into spirv::LoopOp.
-class ForOpConversion final : public SPIRVOpLowering<scf::ForOp> {
+class ForOpConversion final : public SCFToSPIRVPattern<scf::ForOp> {
 public:
-  using SPIRVOpLowering<scf::ForOp>::SPIRVOpLowering;
+  using SCFToSPIRVPattern<scf::ForOp>::SCFToSPIRVPattern;
 
   LogicalResult
   matchAndRewrite(scf::ForOp forOp, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const override;
@@ -32,29 +64,54 @@
 
 /// Pattern to convert a scf::IfOp within kernel functions into
 /// spirv::SelectionOp.
-class IfOpConversion final : public SPIRVOpLowering<scf::IfOp> {
+class IfOpConversion final : public SCFToSPIRVPattern<scf::IfOp> {
 public:
-  using SPIRVOpLowering<scf::IfOp>::SPIRVOpLowering;
+  using SCFToSPIRVPattern<scf::IfOp>::SCFToSPIRVPattern;
 
   LogicalResult
   matchAndRewrite(scf::IfOp ifOp, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const override;
 };
 
-/// Pattern to erase a scf::YieldOp.
-class TerminatorOpConversion final : public SPIRVOpLowering<scf::YieldOp> {
+class TerminatorOpConversion final : public SCFToSPIRVPattern<scf::YieldOp> {
 public:
-  using SPIRVOpLowering<scf::YieldOp>::SPIRVOpLowering;
+  using SCFToSPIRVPattern<scf::YieldOp>::SCFToSPIRVPattern;
 
   LogicalResult
   matchAndRewrite(scf::YieldOp terminatorOp, ArrayRef<Value> operands,
-                  ConversionPatternRewriter &rewriter) const override {
-    rewriter.eraseOp(terminatorOp);
-    return success();
-  }
+                  ConversionPatternRewriter &rewriter) const override;
 };
 } // namespace
 
+/// Helper function to replace SCF op outputs with SPIR-V variable loads.
+/// We create VariableOp to handle the result values of the control flow region.
+/// spv.loop/spv.selection currently don't yield values. Right after the loop
+/// we load the value from the allocation and use it as the SCF op result.
+template <typename ScfOp, typename OpTy>
+static void replaceSCFOutputValue(ScfOp scfOp, OpTy newOp,
+                                  SPIRVTypeConverter &typeConverter,
+                                  ConversionPatternRewriter &rewriter,
+                                  ScfToSPIRVContextImpl *scfToSPIRVContext) {
+
+  Location loc = scfOp.getLoc();
+  auto &allocas = scfToSPIRVContext->outputVars[newOp];
+  SmallVector<Value, 8> resultValue;
+  for (Value result : scfOp.results()) {
+    auto convertedType = typeConverter.convertType(result.getType());
+    auto pointerType =
+        spirv::PointerType::get(convertedType, spirv::StorageClass::Function);
+    rewriter.setInsertionPoint(newOp);
+    auto alloc = rewriter.create<spirv::VariableOp>(
+        loc, pointerType, spirv::StorageClass::Function,
+        /*initializer=*/nullptr);
+    allocas.push_back(alloc);
+    rewriter.setInsertionPointAfter(newOp);
+    Value loadResult = rewriter.create<spirv::LoadOp>(loc, alloc);
+    resultValue.push_back(loadResult);
+  }
+  rewriter.replaceOp(scfOp, resultValue);
+}
+
 //===----------------------------------------------------------------------===//
 // scf::ForOp.
 //===----------------------------------------------------------------------===//
@@ -83,6 +140,8 @@
   // Create the new induction variable to use.
   BlockArgument newIndVar =
       header->addArgument(forOperands.lowerBound().getType());
+  for (Value arg : forOperands.initArgs())
+    header->addArgument(arg.getType());
   Block *body = forOp.getBody();
 
   // Apply signature conversion to the body of the forOp. It has a single block,
@@ -91,29 +150,28 @@
   TypeConverter::SignatureConversion signatureConverter(
       body->getNumArguments());
   signatureConverter.remapInput(0, newIndVar);
-  FailureOr<Block *> newBody = rewriter.convertRegionTypes(
-      &forOp.getLoopBody(), typeConverter, &signatureConverter);
-  if (failed(newBody))
-    return failure();
-  body = *newBody;
-
-  // Delete the loop terminator.
-  rewriter.eraseOp(body->getTerminator());
+  for (unsigned i = 1, e = body->getNumArguments(); i < e; i++)
+    signatureConverter.remapInput(i, header->getArgument(i));
+  body = rewriter.applySignatureConversion(&forOp.getLoopBody(),
+                                           signatureConverter);
 
   // Move the blocks from the forOp into the loopOp. This is the body of the
   // loopOp.
   rewriter.inlineRegionBefore(forOp.getOperation()->getRegion(0), loopOp.body(),
                               std::next(loopOp.body().begin(), 2));
 
+  SmallVector<Value, 8> args(1, forOperands.lowerBound());
+  args.append(forOperands.initArgs().begin(), forOperands.initArgs().end());
   // Branch into it from the entry.
   rewriter.setInsertionPointToEnd(&(loopOp.body().front()));
-  rewriter.create<spirv::BranchOp>(loc, header, forOperands.lowerBound());
+  rewriter.create<spirv::BranchOp>(loc, header, args);
 
   // Generate the rest of the loop header.
   rewriter.setInsertionPointToEnd(header);
   auto *mergeBlock = loopOp.getMergeBlock();
   auto cmpOp = rewriter.create<spirv::SLessThanOp>(
       loc, rewriter.getI1Type(), newIndVar, forOperands.upperBound());
+
   rewriter.create<spirv::BranchConditionalOp>(
       loc, cmpOp, body, ArrayRef<Value>(), mergeBlock, ArrayRef<Value>());
 
@@ -127,7 +185,8 @@
       loc, newIndVar.getType(), newIndVar, forOperands.step());
   rewriter.create<spirv::BranchOp>(loc, header, updatedIndVar);
 
-  rewriter.eraseOp(forOp);
+  replaceSCFOutputValue(forOp, loopOp, typeConverter, rewriter,
+                        scfToSPIRVContext);
   return success();
 }
 
@@ -179,13 +238,45 @@
                                                thenBlock, ArrayRef<Value>(),
                                                elseBlock, ArrayRef<Value>());
 
-  rewriter.eraseOp(ifOp);
+  replaceSCFOutputValue(ifOp, selectionOp, typeConverter, rewriter,
+                        scfToSPIRVContext);
+  return success();
+}
+
+/// Yield is lowered to stores to the VariableOp created during lowering of the
+/// parent region. For loops we also need to update the branch looping back to
+/// the header with the loop carried values.
+LogicalResult TerminatorOpConversion::matchAndRewrite(
+    scf::YieldOp terminatorOp, ArrayRef<Value> operands,
+    ConversionPatternRewriter &rewriter) const {
+  // If the region returns values, store each value into the associated
+  // VariableOp created during lowering of the parent region.
+  if (!operands.empty()) {
+    auto loc = terminatorOp.getLoc();
+    auto &allocas = scfToSPIRVContext->outputVars[terminatorOp.getParentOp()];
+    assert(allocas.size() == operands.size());
+    for (unsigned i = 0, e = operands.size(); i < e; i++)
+      rewriter.create<spirv::StoreOp>(loc, allocas[i], operands[i]);
+    if (isa<spirv::LoopOp>(terminatorOp.getParentOp())) {
+      // For loops we also need to update the branch jumping back to the header.
+      auto br =
+          cast<spirv::BranchOp>(rewriter.getInsertionBlock()->getTerminator());
+      SmallVector<Value, 8> args(br.getBlockArguments());
+      args.append(operands.begin(), operands.end());
+      rewriter.setInsertionPoint(br);
+      rewriter.create<spirv::BranchOp>(terminatorOp.getLoc(), br.getTarget(),
+                                       args);
+      rewriter.eraseOp(br);
+    }
+  }
+  rewriter.eraseOp(terminatorOp);
   return success();
 }
 
 void mlir::populateSCFToSPIRVPatterns(MLIRContext *context,
                                       SPIRVTypeConverter &typeConverter,
+                                      ScfToSPIRVContext &scfToSPIRVContext,
                                       OwningRewritePatternList &patterns) {
   patterns.insert<ForOpConversion, IfOpConversion, TerminatorOpConversion>(
-      context, typeConverter);
+      context, typeConverter, scfToSPIRVContext.getImpl());
 }
Index: mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp
===================================================================
--- mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp
+++ mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp
@@ -589,9 +589,6 @@
 
 void PointerType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
                                 Optional<StorageClass> storage) {
-  if (storage)
-    assert(*storage == getStorageClass() && "inconsistent storage class!");
-
   // Use this pointer type's storage class because this pointer indicates we are
   // using the pointee type in that specific storage class.
   getPointeeType().cast<SPIRVType>().getExtensions(extensions,
@@ -604,9 +601,6 @@
 void PointerType::getCapabilities(
     SPIRVType::CapabilityArrayRefVector &capabilities,
     Optional<StorageClass> storage) {
-  if (storage)
-    assert(*storage == getStorageClass() && "inconsistent storage class!");
-
   // Use this pointer type's storage class because this pointer indicates we are
   // using the pointee type in that specific storage class.
   getPointeeType().cast<SPIRVType>().getCapabilities(capabilities,
Index: mlir/test/Conversion/GPUToSPIRV/if.mlir
===================================================================
--- mlir/test/Conversion/GPUToSPIRV/if.mlir
+++ mlir/test/Conversion/GPUToSPIRV/if.mlir
@@ -89,5 +89,79 @@
     }
     gpu.return
   }
+  // CHECK-LABEL: @simple_if_yield
+  gpu.func @simple_if_yield(%arg2 : memref<10xf32>, %arg3 : i1) kernel
+  attributes {spv.entry_point_abi = {local_size = dense<[16, 1, 1]>: vector<3xi32>}} {
+    // CHECK: %[[VAR1:.*]] = spv.Variable : !spv.ptr<f32, Function>
+    // CHECK: %[[VAR2:.*]] = spv.Variable : !spv.ptr<f32, Function>
+    // CHECK:       spv.selection {
+    // CHECK-NEXT:    spv.BranchConditional {{%.*}}, [[TRUE:\^.*]], [[FALSE:\^.*]]
+    // CHECK-NEXT:  [[TRUE]]:
+    // CHECK:         %[[RET1TRUE:.*]] = spv.constant 0.000000e+00 : f32
+    // CHECK:         %[[RET2TRUE:.*]] = spv.constant 1.000000e+00 : f32
+    // CHECK-DAG:     spv.Store "Function" %[[VAR1]], %[[RET1TRUE]] : f32
+    // CHECK-DAG:     spv.Store "Function" %[[VAR2]], %[[RET2TRUE]] : f32
+    // CHECK:         spv.Branch ^[[MERGE:.*]]
+    // CHECK-NEXT:  [[FALSE]]:
+    // CHECK:         %[[RET2FALSE:.*]] = spv.constant 2.000000e+00 : f32
+    // CHECK:         %[[RET1FALSE:.*]] = spv.constant 3.000000e+00 : f32
+    // CHECK-DAG:     spv.Store "Function" %[[VAR1]], %[[RET1FALSE]] : f32
+    // CHECK-DAG:     spv.Store "Function" %[[VAR2]], %[[RET2FALSE]] : f32
+    // CHECK:         spv.Branch ^[[MERGE]]
+    // CHECK-NEXT:  ^[[MERGE]]:
+    // CHECK:         spv._merge
+    // CHECK-NEXT:  }
+    // CHECK-DAG:   %[[OUT1:.*]] = spv.Load "Function" %[[VAR1]] : f32
+    // CHECK-DAG:   %[[OUT2:.*]] = spv.Load "Function" %[[VAR2]] : f32
+    // CHECK:       spv.Store "StorageBuffer" {{%.*}}, %[[OUT1]] : f32
+    // CHECK:       spv.Store "StorageBuffer" {{%.*}}, %[[OUT2]] : f32
+    // CHECK:       spv.Return
+    %0:2 = scf.if %arg3 -> (f32, f32) {
+      %c0 = constant 0.0 : f32
+      %c1 = constant 1.0 : f32
+      scf.yield %c0, %c1 : f32, f32
+    } else {
+      %c0 = constant 2.0 : f32
+      %c1 = constant 3.0 : f32
+      scf.yield %c1, %c0 : f32, f32
+    }
+    %i = constant 0 : index
+    %j = constant 1 : index
+    store %0#0, %arg2[%i] : memref<10xf32>
+    store %0#1, %arg2[%j] : memref<10xf32>
+    gpu.return
+  }
+  // TODO(thomasraoux): The transformation should only be legal if
+  // VariablePointer capability is supported. This test is still useful to
+  // make sure we can handle scf op result with type change.
+  // CHECK-LABEL: @simple_if_yield_type_change
+  // CHECK: %[[VAR:.*]] = spv.Variable : !spv.ptr<!spv.ptr<!spv.struct<!spv.array<10 x f32 [4]> [0]>, StorageBuffer>, Function>
+  // CHECK:       spv.selection {
+  // CHECK-NEXT:    spv.BranchConditional {{%.*}}, [[TRUE:\^.*]], [[FALSE:\^.*]]
+  // CHECK-NEXT:  [[TRUE]]:
+  // CHECK:         spv.Store "Function" %[[VAR]], {{%.*}} : !spv.ptr<!spv.struct<!spv.array<10 x f32 [4]> [0]>, StorageBuffer>
+  // CHECK:         spv.Branch ^[[MERGE:.*]]
+  // CHECK-NEXT:  [[FALSE]]:
+  // CHECK:         spv.Store "Function" %[[VAR]], {{%.*}} : !spv.ptr<!spv.struct<!spv.array<10 x f32 [4]> [0]>, StorageBuffer>
+  // CHECK:         spv.Branch ^[[MERGE]]
+  // CHECK-NEXT:  ^[[MERGE]]:
+  // CHECK:         spv._merge
+  // CHECK-NEXT:  }
+  // CHECK:       %[[OUT:.*]] = spv.Load "Function" %[[VAR]] : !spv.ptr<!spv.struct<!spv.array<10 x f32 [4]> [0]>, StorageBuffer>
+  // CHECK:       %[[ADD:.*]] = spv.AccessChain %[[OUT]][{{%.*}}, {{%.*}}] : !spv.ptr<!spv.struct<!spv.array<10 x f32 [4]> [0]>, StorageBuffer>
+  // CHECK:       spv.Store "StorageBuffer" %[[ADD]], {{%.*}} : f32
+  // CHECK:       spv.Return
+  gpu.func @simple_if_yield_type_change(%arg2 : memref<10xf32>, %arg3 : memref<10xf32>, %arg4 : i1) kernel
+  attributes {spv.entry_point_abi = {local_size = dense<[16, 1, 1]>: vector<3xi32>}} {
+    %i = constant 0 : index
+    %value = constant 0.0 : f32
+    %0 = scf.if %arg4 -> (memref<10xf32>) {
+      scf.yield %arg2 : memref<10xf32>
+    } else {
+      scf.yield %arg3 : memref<10xf32>
+    }
+    store %value, %0[%i] : memref<10xf32>
+    gpu.return
+  }
 }
 }
Index: mlir/test/Conversion/GPUToSPIRV/loop.mlir
===================================================================
--- mlir/test/Conversion/GPUToSPIRV/loop.mlir
+++ mlir/test/Conversion/GPUToSPIRV/loop.mlir
@@ -51,5 +51,48 @@
     }
     gpu.return
   }
+
+
+  // CHECK-LABEL: @loop_yield
+  gpu.func @loop_yield(%arg2 : memref<10xf32>, %arg3 : memref<10xf32>) kernel
+  attributes {spv.entry_point_abi = {local_size = dense<[16, 1, 1]>: vector<3xi32>}} {
+    // CHECK: %[[LB:.*]] = spv.constant 4 : i32
+    %lb = constant 4 : index
+    // CHECK: %[[UB:.*]] = spv.constant 42 : i32
+    %ub = constant 42 : index
+    // CHECK: %[[STEP:.*]] = spv.constant 2 : i32
+    %step = constant 2 : index
+    // CHECK: %[[INITVAR1:.*]] = spv.constant 0.000000e+00 : f32
+    %s0 = constant 0.0 : f32
+    // CHECK: %[[INITVAR2:.*]] = spv.constant 1.000000e+00 : f32
+    %s1 = constant 1.0 : f32
+    // CHECK: %[[VAR1:.*]] = spv.Variable : !spv.ptr<f32, Function>
+    // CHECK: %[[VAR2:.*]] = spv.Variable : !spv.ptr<f32, Function>
+    // CHECK: spv.loop {
+    // CHECK: spv.Branch ^[[HEADER:.*]](%[[LB]], %[[INITVAR1]], %[[INITVAR2]] : i32, f32, f32)
+    // CHECK: ^[[HEADER]](%[[INDVAR:.*]]: i32, %[[CARRIED1:.*]]: f32, %[[CARRIED2:.*]]: f32):
+    // CHECK: %[[CMP:.*]] = spv.SLessThan %[[INDVAR]], %[[UB]] : i32
+    // CHECK: spv.BranchConditional %[[CMP]], ^[[BODY:.*]], ^[[MERGE:.*]]
+    // CHECK: ^[[BODY]]:
+    // CHECK: %[[UPDATED:.*]] = spv.FAdd %[[CARRIED1]], %[[CARRIED1]] : f32
+    // CHECK-DAG: %[[INCREMENT:.*]] = spv.IAdd %[[INDVAR]], %[[STEP]] : i32
+    // CHECK-DAG: spv.Store "Function" %[[VAR1]], %[[UPDATED]] : f32
+    // CHECK-DAG: spv.Store "Function" %[[VAR2]], %[[UPDATED]] : f32
+    // CHECK: spv.Branch ^[[HEADER]](%[[INCREMENT]], %[[UPDATED]], %[[UPDATED]] : i32, f32, f32)
+    // CHECK: ^[[MERGE]]:
+    // CHECK: spv._merge
+    // CHECK: }
+    %result:2 = scf.for %i0 = %lb to %ub step %step iter_args(%si = %s0, %sj = %s1) -> (f32, f32) {
+      %sn = addf %si, %si : f32
+      scf.yield %sn, %sn : f32, f32
+    }
+    // CHECK-DAG: %[[OUT1:.*]] = spv.Load "Function" %[[VAR1]] : f32
+    // CHECK-DAG: %[[OUT2:.*]] = spv.Load "Function" %[[VAR2]] : f32
+    // CHECK: spv.Store "StorageBuffer" {{%.*}}, %[[OUT1]] : f32
+    // CHECK: spv.Store "StorageBuffer" {{%.*}}, %[[OUT2]] : f32
+    store %result#0, %arg3[%lb] : memref<10xf32>
+    store %result#1, %arg3[%ub] : memref<10xf32>
+    gpu.return
+  }
 }
 }