diff --git a/mlir/include/mlir/Dialect/SPIRV/Transforms/Passes.h b/mlir/include/mlir/Dialect/SPIRV/Transforms/Passes.h --- a/mlir/include/mlir/Dialect/SPIRV/Transforms/Passes.h +++ b/mlir/include/mlir/Dialect/SPIRV/Transforms/Passes.h @@ -22,6 +22,7 @@ namespace spirv { class ModuleOp; +class TargetEnvAttr; //===----------------------------------------------------------------------===// // Passes @@ -69,8 +70,9 @@ /// Creates an operation pass that unifies access of multiple aliased resources /// into access of one single resource. +using GetTargetEnvFn = std::function; std::unique_ptr> -createUnifyAliasedResourcePass(); +createUnifyAliasedResourcePass(GetTargetEnvFn getTargetEnv = nullptr); //===----------------------------------------------------------------------===// // Registration diff --git a/mlir/lib/Dialect/SPIRV/Transforms/UnifyAliasedResourcePass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/UnifyAliasedResourcePass.cpp --- a/mlir/lib/Dialect/SPIRV/Transforms/UnifyAliasedResourcePass.cpp +++ b/mlir/lib/Dialect/SPIRV/Transforms/UnifyAliasedResourcePass.cpp @@ -16,6 +16,7 @@ #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h" +#include "mlir/Dialect/SPIRV/IR/TargetAndABI.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinTypes.h" @@ -223,7 +224,8 @@ } if (auto addressOp = dyn_cast(op)) { auto moduleOp = addressOp->getParentOfType(); - auto *varOp = SymbolTable::lookupSymbolIn(moduleOp, addressOp.getVariable()); + auto *varOp = + SymbolTable::lookupSymbolIn(moduleOp, addressOp.getVariable()); return shouldUnify(varOp); } @@ -517,8 +519,8 @@ Value value = adaptor.getValue(); if (srcElemType != dstElemType) value = rewriter.create(loc, dstElemType, value); - rewriter.replaceOpWithNewOp(storeOp, adaptor.getPtr(), value, - storeOp->getAttrs()); + rewriter.replaceOpWithNewOp(storeOp, adaptor.getPtr(), + value, storeOp->getAttrs()); return 
success(); } }; @@ -532,7 +534,13 @@ : public spirv::impl::SPIRVUnifyAliasedResourcePassBase< UnifyAliasedResourcePass> { public: + explicit UnifyAliasedResourcePass(spirv::GetTargetEnvFn getTargetEnv) + : getTargetEnvFn(std::move(getTargetEnv)) {} + void runOnOperation() override; + +private: + spirv::GetTargetEnvFn getTargetEnvFn; }; } // namespace @@ -540,6 +548,14 @@ spirv::ModuleOp moduleOp = getOperation(); MLIRContext *context = &getContext(); + if (getTargetEnvFn) { + // This pass is only needed when targeting Apple GPUs via MoltenVK, where + // SPIR-V is translated into MSL and that translation has limitations. Skip + // other vendors; without a callback the pass always runs. + if (getTargetEnvFn(moduleOp).getVendorID() != spirv::Vendor::Apple) + return; + } + // Analyze aliased resources first. ResourceAliasAnalysis &analysis = getAnalysis(); @@ -570,6 +586,6 @@ } std::unique_ptr> -spirv::createUnifyAliasedResourcePass() { - return std::make_unique(); +spirv::createUnifyAliasedResourcePass(spirv::GetTargetEnvFn getTargetEnv) { + return std::make_unique(std::move(getTargetEnv)); }