[MPS] Add logit op
qqaatw committed Feb 20, 2023
commit a3895c8850863fb3cbbc5a535848b20df2b1d316
158 changes: 158 additions & 0 deletions aten/src/ATen/native/mps/operations/UnaryOps.mm
@@ -244,6 +244,164 @@ void unary_op(const Tensor& self, const Tensor& output, std::string op_name, Una
  });
}

void logit_mps_impl(const Tensor& self, c10::optional<double> eps, Tensor& output, const std::string op_name) {
  std::string key = op_name + ":[" + (eps.has_value() ? std::to_string(eps.value()) : "NULL") + "]";

  mps::unary_op(self, output, key,
                ^ MPSGraphTensor* (MPSGraph* mpsGraph, MPSGraphTensor* inputTensor) {
    MPSGraphTensor* oneTensor = [mpsGraph constantWithScalar:1.0
                                                       shape:@[@1]
                                                    dataType:inputTensor.dataType];

    MPSGraphTensor* logitInputTensor;
    if (eps.has_value()) {
      MPSGraphTensor* lowTensor = [mpsGraph constantWithScalar:eps.value()
                                                         shape:@[@1]
                                                      dataType:inputTensor.dataType];
      MPSGraphTensor* highTensor = [mpsGraph subtractionWithPrimaryTensor:oneTensor
                                                          secondaryTensor:lowTensor
                                                                     name:nil];
      logitInputTensor = [mpsGraph clampWithTensor:inputTensor
                                    minValueTensor:lowTensor
                                    maxValueTensor:highTensor
                                              name:nil];
    } else {
      logitInputTensor = inputTensor;
    }

    MPSGraphTensor* oneMinusInputTensor = [mpsGraph subtractionWithPrimaryTensor:oneTensor
                                                                 secondaryTensor:logitInputTensor
                                                                            name:nil];
    MPSGraphTensor* outputTensor = [mpsGraph divisionWithPrimaryTensor:logitInputTensor
                                                       secondaryTensor:oneMinusInputTensor
                                                                  name:nil];

    return [mpsGraph logarithmWithTensor:outputTensor
                                    name:nil];
  });
}

Tensor& logit_out_mps(const Tensor& self,
                      c10::optional<double> eps,
                      Tensor& result) {
  logit_mps_impl(self, eps, result, "logit_out_mps");
  return result;
}

Tensor logit_mps(const Tensor& self, c10::optional<double> eps) {
  Tensor result = at::native::empty_mps(
      self.sizes(),
      ScalarType::Float,
      c10::nullopt,
      kMPS,
      c10::nullopt,
      c10::nullopt);
  logit_mps_impl(self, eps, result, "logit_mps");
  return result;
}

TORCH_IMPL_FUNC(logit_backward_out_mps) (
  const Tensor& grad_output,
  const Tensor& input,
  c10::optional<double> eps,
  const Tensor& grad_input)
{
  using namespace mps;

  // Empty output
  if (grad_input.numel() == 0)
    return;

  double eps_ = eps ? eps.value() : -1.0;

  struct CachedGraph : public MPSCachedGraph
  {
    CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {}
    MPSGraphTensor* gradOutputTensor_ = nil;
    MPSGraphTensor* inputTensor_ = nil;
    MPSGraphTensor* outputTensor_ = nil;
  };

  MPSGraphCache* cache_ = MPSGraphCache::getInstance();

  MPSStream* stream = getCurrentMPSStream();

  @autoreleasepool {
    std::string key = "logit_backward_out_mps:" + getTensorsStringKey({grad_output, input}) + ":" +
        "[" + (eps.has_value() ? std::to_string(eps.value()) : "-1") + "]";

    CachedGraph* cachedGraph = static_cast<CachedGraph*>(cache_->LookUp(key));
    if (!cachedGraph) {
      MPSCachedGraph* tmpCachedGraph = cache_->CreateCachedGraph(key, ^ MPSCachedGraph* () {

        CachedGraph* newCachedGraph = nil;

        @autoreleasepool {
          MPSGraph* mpsGraph = make_mps_graph();
          newCachedGraph = new CachedGraph(mpsGraph);

          MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, input);
          MPSGraphTensor* gradOutputTensor = mpsGraphRankedPlaceHolder(mpsGraph, grad_output);
          MPSGraphTensor* outputTensor = mpsGraphRankedPlaceHolder(mpsGraph, grad_input);
          MPSGraphTensor* zeroTensor = [mpsGraph constantWithScalar:0.0
                                                              shape:@[@1]
                                                           dataType:inputTensor.dataType];
          MPSGraphTensor* oneTensor = [mpsGraph constantWithScalar:1.0
                                                             shape:@[@1]
                                                          dataType:inputTensor.dataType];
          MPSGraphTensor* lowTensor = [mpsGraph constantWithScalar:eps_
                                                             shape:@[@1]
                                                          dataType:inputTensor.dataType];
          MPSGraphTensor* inputLessThanLowPredicateTensor = [mpsGraph lessThanWithPrimaryTensor:inputTensor
                                                                                secondaryTensor:lowTensor
                                                                                           name:nil];
          MPSGraphTensor* highTensor = [mpsGraph subtractionWithPrimaryTensor:oneTensor
                                                              secondaryTensor:lowTensor
                                                                         name:nil];
          MPSGraphTensor* inputGreaterThanHighPredicateTensor = [mpsGraph greaterThanWithPrimaryTensor:inputTensor
                                                                                       secondaryTensor:highTensor
                                                                                                  name:nil];
          MPSGraphTensor* outOfIntervalTensor = [mpsGraph logicalORWithPrimaryTensor:inputLessThanLowPredicateTensor
                                                                     secondaryTensor:inputGreaterThanHighPredicateTensor
                                                                                name:nil];
          MPSGraphTensor* oneMinusInputTensor = [mpsGraph subtractionWithPrimaryTensor:oneTensor
                                                                       secondaryTensor:inputTensor
                                                                                  name:nil];
          outputTensor = [mpsGraph multiplicationWithPrimaryTensor:inputTensor
                                                   secondaryTensor:oneMinusInputTensor
                                                              name:nil];
          outputTensor = [mpsGraph divisionWithPrimaryTensor:gradOutputTensor
                                             secondaryTensor:outputTensor
                                                        name:nil];
          outputTensor = [mpsGraph selectWithPredicateTensor:outOfIntervalTensor
                                         truePredicateTensor:zeroTensor
                                        falsePredicateTensor:outputTensor
                                                        name:nil];

          newCachedGraph->gradOutputTensor_ = gradOutputTensor;
          newCachedGraph->inputTensor_ = inputTensor;
          newCachedGraph->outputTensor_ = outputTensor;
        }
        return newCachedGraph;
      });
      cachedGraph = static_cast<CachedGraph*>(tmpCachedGraph);
    }
    Placeholder gradOutputPlaceholder = Placeholder(cachedGraph->gradOutputTensor_, grad_output);
    Placeholder inputPlaceholder = Placeholder(cachedGraph->inputTensor_, input);
    Placeholder gradInputPlaceholder = Placeholder(cachedGraph->outputTensor_, grad_input);

    // Create dictionary of inputs and outputs
    NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* feeds = @{
      gradOutputPlaceholder.getMPSGraphTensor() : gradOutputPlaceholder.getMPSGraphTensorData(),
      inputPlaceholder.getMPSGraphTensor() : inputPlaceholder.getMPSGraphTensorData(),
    };
    NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results = @{
      gradInputPlaceholder.getMPSGraphTensor() : gradInputPlaceholder.getMPSGraphTensorData()
    };
    runMPSGraph(stream, cachedGraph->graph(), feeds, results);
  }
}



TORCH_IMPL_FUNC(cumsum_out_mps)
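As a reference (not part of this diff), the graphs above compute logit(x) = log(x / (1 - x)), clamping the input to [eps, 1 - eps] when eps is given, and the backward kernel computes grad_output / (x * (1 - x)), zeroed wherever x falls outside that interval. The following is a minimal CPU-side sketch of the same math using plain torch ops; the function names are illustrative only.

import torch

def logit_reference(x, eps=None):
    # Forward: optionally clamp to [eps, 1 - eps], then log(x / (1 - x)).
    if eps is not None:
        x = x.clamp(min=eps, max=1.0 - eps)
    return torch.log(x / (1.0 - x))

def logit_backward_reference(grad_output, x, eps=None):
    # Backward: grad_output / (x * (1 - x)), zeroed outside [eps, 1 - eps].
    # When eps is None the kernel uses eps_ = -1.0, so the mask is effectively never hit.
    eps_ = eps if eps is not None else -1.0
    grad_input = grad_output / (x * (1.0 - x))
    out_of_interval = (x < eps_) | (x > 1.0 - eps_)
    return torch.where(out_of_interval, torch.zeros_like(grad_input), grad_input)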
3 changes: 3 additions & 0 deletions aten/src/ATen/native/native_functions.yaml
@@ -4926,6 +4926,7 @@
  variants: function, method
  dispatch:
    CPU, CUDA: logit
    MPS: logit_mps
  tags: pointwise

- func: logit_(Tensor(a!) self, float? eps=None) -> Tensor(a!)
@@ -4937,6 +4938,7 @@
- func: logit.out(Tensor self, float? eps=None, *, Tensor(a!) out) -> Tensor(a!)
  dispatch:
    CPU, CUDA: logit_out
    MPS: logit_out_mps
  tags: pointwise

- func: sin(Tensor self) -> Tensor
@@ -12130,6 +12132,7 @@
  structured_inherits: TensorIteratorBase
  dispatch:
    CPU, CUDA: logit_backward_out
    MPS: logit_backward_out_mps
  tags: pointwise

- func: logit_backward(Tensor grad_output, Tensor self, float? eps=None) -> Tensor
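Once these dispatch entries are registered, torch.logit on an "mps" tensor should route to the new kernels. A small smoke test, assuming an MPS-capable machine (not part of the diff):

import torch

if torch.backends.mps.is_available():
    x = torch.rand(4, device="mps")
    y = torch.logit(x, eps=1e-6)             # forward: logit_mps / logit_out_mps
    y_ref = torch.logit(x.cpu(), eps=1e-6)   # CPU reference path
    assert torch.allclose(y.cpu(), y_ref, atol=1e-5)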
2 changes: 2 additions & 0 deletions test/test_mps.py
@@ -9223,6 +9223,7 @@ class TestConsistency(TestCaseMPS):
        'logical_not': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
        'logical_or': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
        'logical_xor': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
        'logit': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
        'logspace': ['f32', 'i16', 'i32', 'i64', 'u8'],
        'logsumexp': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
        'masked_fill': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
@@ -9472,6 +9473,7 @@ class TestConsistency(TestCaseMPS):
        'log_softmax': ['f32'],
        'logaddexp': ['f32'],
        'logical_not': ['f16', 'f32'],
        'logit': ['f16', 'f32'],
        'logspace': ['f32'],
        'matmul': ['f32'],
        'mm': ['f32'],
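For illustration (not part of the diff), the kind of forward/backward consistency these allow-list entries gate can also be reproduced manually, assuming an MPS-capable machine:

import torch

if torch.backends.mps.is_available():
    x_cpu = torch.rand(8, dtype=torch.float32, requires_grad=True)
    x_mps = x_cpu.detach().to("mps").requires_grad_(True)
    torch.logit(x_cpu, eps=1e-4).sum().backward()
    torch.logit(x_mps, eps=1e-4).sum().backward()   # exercises logit_backward_out_mps
    assert torch.allclose(x_cpu.grad, x_mps.grad.cpu(), atol=1e-5)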